示例#1
0
def ax_cube(objective, n_trials, n_dim, with_count=False, method=None):
    """Minimise ``objective`` over the unit hypercube [0, 1]**n_dim via ``optimize``.

    Args:
        objective: callable taking a list of ``n_dim`` floats (ordered u0, u1, ...)
            and returning the value to minimise.
        n_trials: passed to ``optimize`` as ``total_trials``.
        n_dim: number of dimensions of the search cube.
        with_count: when true, also return the number of objective evaluations.
        method: accepted but not used by this function.

    Returns:
        ``(best_val, best_x)``, or ``(best_val, best_x, feval_count)`` when
        ``with_count`` is true.

    Side effects:
        Resets and increments the module-level global ``feval_count`` and sets
        the 'ax' logger's level to CRITICAL.
    """
    global feval_count
    feval_count = 0
    ax_logger = get_logger('ax')
    ax_logger.setLevel(CRITICAL)

    # Parameter names u0..u{n_dim-1}, in the order the objective expects.
    names = ["u" + str(i) for i in range(n_dim)]

    def evaluation_func(prms):
        global feval_count
        feval_count += 1
        return objective([prms[name] for name in names])

    parameters = [
        {"name": name, "type": "range", "bounds": [0.0, 1.0]}
        for name in names
    ]
    best_parameters, best_values, _experiment, _model = optimize(
        parameters=parameters,
        evaluation_function=evaluation_func,
        minimize=True,
        total_trials=n_trials,
    )
    best_x = [best_parameters[name] for name in names]
    best_val = best_values[0]['objective']
    if with_count:
        return best_val, best_x, feval_count
    return best_val, best_x
示例#2
0
 def testLoggerWithFile(self):
     """Check that a file handler on a child logger records the logger name and message.

     The handler is detached and closed in ``finally``. A logger obtained by
     name is shared, so a leftover handler would keep pointing at the deleted
     temporary file for every later user of that logger.
     """
     with NamedTemporaryFile() as tf:
         logger = get_logger(BASE_LOGGER_NAME + ".testLoggerWithFile")
         handler = build_file_handler(tf.name)
         logger.addHandler(handler)
         try:
             logger.info(self.warning_string)
             # Decode rather than str(): str(bytes) gives the repr (b'...'),
             # which escapes quotes, newlines and non-ASCII characters.
             output = tf.read().decode("utf-8")
             self.assertIn(BASE_LOGGER_NAME, output)
             self.assertIn(self.warning_string, output)
         finally:
             logger.removeHandler(handler)
             handler.close()
示例#3
0
 def testLogger(self):
     """Check that patching ``warning`` on the logger intercepts the call.

     ``patch.object`` is used as a context manager, so the real method is
     restored even if the assertion fails. A bare start()/stop() pair would
     leave the mock set on the python logger when the assertion raised before
     ``stop()`` ran. In some environments (like pytest) the mock would then
     leak into other tests.
     """
     logger = get_logger(__name__)
     with patch.object(logger, "warning") as mock_warning:
         logger.warning(self.warning_string)
         mock_warning.assert_called_once_with(self.warning_string)
示例#4
0
 def testLoggerOutputNameWithFile(self):
     """Check that an ``output_name`` supplied via LoggerAdapter extras reaches the file.

     The file handler is detached and closed in ``finally``, so it does not
     stay attached to the shared named logger after the temp file is deleted.
     """
     with NamedTemporaryFile() as tf:
         base_logger = get_logger(
             BASE_LOGGER_NAME + ".testLoggerOutputNameWithFile")
         handler = build_file_handler(tf.name)
         base_logger.addHandler(handler)
         try:
             logger = logging.LoggerAdapter(
                 base_logger, {"output_name": "my_output_name"})
             logger.warning(self.warning_string)
             # Decode rather than str(): str(bytes) gives the escaped repr.
             output = tf.read().decode("utf-8")
             self.assertIn("my_output_name", output)
             self.assertIn(self.warning_string, output)
         finally:
             base_logger.removeHandler(handler)
             handler.close()
示例#5
0
 def testLogger(self):
     """Check that ``warning`` works unpatched, then that a patched ``warning`` is called."""
     logger = get_logger(BASE_LOGGER_NAME + ".testLogger")
     # Verify it doesn't crash
     logger.warning(self.warning_string)
     # Patch it, verify we actually called it. The context manager restores
     # the real method even when the assertion fails. Otherwise the mock
     # would stay set on the python logger and, in some environments (like
     # pytest), leak into other tests.
     with patch.object(logger, "warning") as mock_warning:
         logger.warning(self.warning_string)
         mock_warning.assert_called_once_with(self.warning_string)
示例#6
0
    def test_validate_kwarg_typing(self):
        """Exercise ``validate_kwarg_typing`` on valid and invalid kwargs.

        - Valid kwarg/callable combinations must not raise.
        - A keyword accepted by none of the callables raises ``ValueError``.
        - A keyword shared by several callables is reported via ``logger.debug``.
        - Type mismatches are reported via ``logger.warning``.
        """
        def typed_callable(arg1: int, arg2: str = None) -> None:
            pass

        def typed_callable_with_dict(arg3: int, arg4: Dict[str, int]) -> None:
            pass

        def typed_callable_valid(arg3: int, arg4: str = None) -> None:
            pass

        def typed_callable_dup_keyword(arg2: int, arg4: str = None) -> None:
            pass

        def typed_callable_with_callable(
                arg1: int, arg2: Callable[[int], Dict[str, int]]) -> None:
            pass

        def typed_callable_extra_arg(arg1: int, arg2: str, arg3: bool) -> None:
            pass

        # Valid kwargs: none of these calls may raise.
        valid_cases = [
            # pass
            ({"arg1": 1, "arg2": "test", "arg3": 2},
             [typed_callable, typed_callable_valid]),
            # pass with complex data structure
            ({"arg1": 1, "arg2": "test", "arg3": 2, "arg4": {"k1": 1}},
             [typed_callable, typed_callable_with_dict]),
            # callable as arg (same arg count but diff type)
            ({"arg1": 1, "arg2": typed_callable},
             [typed_callable_with_callable]),
            # callable as arg (diff arg count)
            ({"arg1": 1, "arg2": typed_callable_extra_arg},
             [typed_callable_with_callable]),
        ]
        for kwargs, typed_callables in valid_cases:
            try:
                validate_kwarg_typing(typed_callables, **kwargs)
            except Exception as e:
                # self.fail is the idiomatic unconditional failure. The
                # message names the offending kwargs and the exception raised.
                self.fail(f"Exception raised on valid kwargs {kwargs}: {e!r}")

        # kwargs contains extra keywords
        kwargs = {"arg1": 1, "arg2": "test", "arg3": 3, "arg5": 4}
        typed_callables = [typed_callable, typed_callable_valid]
        with self.assertRaises(ValueError):
            validate_kwarg_typing(typed_callables, **kwargs)

        # callables have duplicate keywords
        logger = get_logger("ax.utils.common.kwargs")
        with patch.object(logger, "debug") as mock_debug:
            kwargs = {"arg1": 1, "arg2": "test", "arg4": "test_again"}
            typed_callables = [typed_callable, typed_callable_dup_keyword]
            validate_kwarg_typing(typed_callables, **kwargs)
            mock_debug.assert_called_once_with(
                f"`{typed_callables}` have duplicate keyword argument: arg2.")

        # mismatch types
        with patch.object(logger, "warning") as mock_warning:
            kwargs = {"arg1": 1, "arg2": "test", "arg3": "test_again"}
            typed_callables = [typed_callable, typed_callable_valid]
            validate_kwarg_typing(typed_callables, **kwargs)
            expected_message = (
                f"`{typed_callable_valid}` expected argument `arg3` to be of type"
                f" {type(1)}. Got test_again (type: {type('test_again')}).")
            mock_warning.assert_called_once_with(expected_message)

        # mismatch types with Dict
        with patch.object(logger, "warning") as mock_warning:
            str_dic = {"k1": "test"}
            kwargs = {"arg1": 1, "arg2": "test", "arg3": 2, "arg4": str_dic}
            typed_callables = [typed_callable, typed_callable_with_dict]
            validate_kwarg_typing(typed_callables, **kwargs)
            expected_message = (
                f"`{typed_callable_with_dict}` expected argument `arg4` to be of type"
                f" typing.Dict[str, int]. Got {str_dic} (type: {type(str_dic)})."
            )
            mock_warning.assert_called_once_with(expected_message)

        # mismatch types with callable as arg
        with patch.object(logger, "warning") as mock_warning:
            kwargs = {"arg1": 1, "arg2": "test_again"}
            typed_callables = [typed_callable_with_callable]
            validate_kwarg_typing(typed_callables, **kwargs)
            expected_message = (
                f"`{typed_callable_with_callable}` expected argument `arg2` to be of"
                f" type typing.Callable[[int], typing.Dict[str, int]]. "
                f"Got test_again (type: {type('test_again')}).")
            mock_warning.assert_called_once_with(expected_message)
示例#7
0
from ax import optimize
from logging import CRITICAL
from ax.utils.common.logger import get_logger
# Set the 'ax' logger's level to CRITICAL so records below that severity
# logged through it are dropped.
rt = get_logger('ax')
rt.setLevel(CRITICAL)
import warnings
# Ignore all Python warnings for the whole process (no category filter).
warnings.filterwarnings("ignore")
from funcy import print_durations
from humpday.objectives.classic import CLASSIC_OBJECTIVES


def ax_cube(objective, n_trials, n_dim, with_count=False, method=None):
    global feval_count
    feval_count = 0
    rt = get_logger('ax')
    rt.setLevel(CRITICAL)

    def evaluation_func(prms):
        global feval_count
        feval_count += 1
        return objective([prms["u" + str(i)] for i in range(n_dim)])

    parameters = [{
        "name": "u" + str(i),
        "type": "range",
        "bounds": [0.0, 1.0],
    } for i in range(n_dim)]
    best_parameters, best_values, experiment, model = optimize(
        parameters=parameters,
        evaluation_function=evaluation_func,
        minimize=True,
示例#8
0
from collections import defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Tuple, Union

import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig, TParamValue
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger

# Import needed only for annotations: typing.TYPE_CHECKING is False at
# runtime, so this import runs only under static type checkers.
if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax.modelbridge import base as base_modelbridge  # noqa F401  # pragma: no cover

# Module-level logger, named after the StandardizeY transform.
logger = get_logger("StandardizeY")


class StandardizeY(Transform):
    """Standardize Y, separately for each metric.

    Transform is done in-place.
    """
    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
    ) -> None:
        if len(observation_data) == 0:
示例#9
0
from ax.core.types import (
    ComparisonOp,
    TModelCov,
    TModelMean,
    TModelPredict,
    TModelPredictArm,
    TParameterization,
)
from ax.metrics.branin import BraninMetric
from ax.metrics.factorial import FactorialMetric
from ax.metrics.hartmann6 import Hartmann6Metric
from ax.modelbridge.factory import Cont_X_trans, get_factorial, get_sobol
from ax.runners.synthetic import SyntheticRunner
from ax.utils.common.logger import get_logger

logger = get_logger("ae_experiment")  # module-level logger for these helpers

# Experiments


def get_experiment() -> Experiment:
    """Build the small ``Experiment`` named "test".

    It combines the search space, optimization config and status quo from the
    sibling ``get_*`` helpers with one tracking metric named "tracking". The
    experiment is marked as a test.
    """
    search_space = get_search_space()
    optimization_config = get_optimization_config()
    status_quo = get_status_quo()
    tracking_metrics = [Metric(name="tracking")]
    return Experiment(
        name="test",
        search_space=search_space,
        optimization_config=optimization_config,
        status_quo=status_quo,
        description="test description",
        tracking_metrics=tracking_metrics,
        is_test=True,
    )
示例#10
0
import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger
from scipy.stats import norm


# Import used only for annotations. TYPE_CHECKING is presumably
# typing.TYPE_CHECKING (its import is not visible here), which is False at
# runtime.
if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax.modelbridge import base as base_modelbridge  # noqa F401  # pragma: no cover


# Module-level logger, named after the transform defined in this module.
logger = get_logger("InverseGaussianCdfY")


# TODO(jej): Add OptimizationConfig validation - can't transform outcome constraints.
class InverseGaussianCdfY(Transform):
    """Apply inverse CDF transform to Y.

    This means that we model uniform distributions as gaussian-distributed.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
示例#11
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict
from unittest.mock import patch

from ax.utils.common.kwargs import validate_kwarg_typing, warn_on_kwargs
from ax.utils.common.logger import get_logger
from ax.utils.common.testutils import TestCase

logger = get_logger("ax.utils.common.kwargs")  # same logger name the kwargs module under test uses


class TestKwargUtils(TestCase):
    def test_validate_kwarg_typing(self):
        def typed_callable(arg1: int, arg2: str = None) -> None:
            pass

        def typed_callable_with_dict(arg3: int, arg4: Dict[str, int]) -> None:
            pass

        def typed_callable_valid(arg3: int, arg4: str = None) -> None:
            pass

        def typed_callable_dup_keyword(arg2: int, arg4: str = None) -> None:
            pass

        def typed_callable_with_callable(
示例#12
0
from ax.modelbridge.modelbridge_utils import (
    extract_objective_thresholds,
    validate_and_apply_final_transform,
)
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.transforms.base import Transform
from ax.models.torch.frontier_utils import (
    TFrontierEvaluator,
    get_default_frontier_evaluator,
)
from ax.models.torch_base import TorchModel
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_optional, not_none

logger = get_logger("MultiObjectiveTorchModelBridge")  # module-level logger


class MultiObjectiveTorchModelBridge(TorchModelBridge):
    """A model bridge for using multi-objective torch-based models.

    Specifies an interface that is implemented by MultiObjectiveTorchModel. In
    particular, model should have methods fit, predict, and gen. See
    MultiObjectiveTorchModel for the API for each of these methods.

    Requires that all parameters have been transformed to RangeParameters
    or FixedParameters with float type and no log scale.

    This class converts Ax parameter types to torch tensors before passing
    them to the model.
    """
示例#13
0
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun, extract_arm_predictions
from ax.core.observation import (
    Observation,
    ObservationData,
    ObservationFeatures,
    observations_from_data,
)
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig, TModelCov, TModelMean, TModelPredict
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger

logger = get_logger("ModelBridge")  # module-level logger


class ModelBridge(ABC):
    """The main object for using models in Ax.

    ModelBridge specifies 3 methods for using models:

    - predict: Make model predictions. This method is not optimized for
      speed and so should be used primarily for plotting or similar tasks
      and not inside an optimization loop.
    - gen: Use the model to generate new candidates.
    - cross_validate: Do cross validation to assess model predictions.

    ModelBridge converts Ax types like Data and Arm to types that are
    meant to be consumed by the models. The data sent to the model will depend
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Dict, List

import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger

logger = get_logger("IVW")  # module-level logger for inverse-variance-weighting helpers


def ivw_metric_merge(obsd: ObservationData,
                     conflicting_noiseless: str = "warn") -> ObservationData:
    """Merge multiple observations of a metric with inverse variance weighting.

    Correctly updates the covariance of the new merged estimates:
    ybar1 = Sum_i w_i * y_i
    ybar2 = Sum_j w_j * y_j
    cov[ybar1, ybar2] = Sum_i Sum_j w_i * w_j * cov[y_i, y_j]

    w_i will be infinity if any variance is 0. If one variance is 0., then
    the IVW estimate is the corresponding mean. If there are multiple
    measurements with 0 variance but means are all the same, then IVW estimate
    is that mean. If there are multiple measurements and means differ, behavior
    depends on argument conflicting_noiseless. "ignore" and "warn" will use
    the first of the measurements as the IVW estimate. "warn" will additionally
    log a warning. "raise" will raise an exception.

    Args:
示例#15
0
文件: test_logger.py 项目: xiecong/Ax
 def testLogger(self):
     """Check that a patched ``warning`` on the module logger receives the message.

     The patch is applied as a context manager, so the real ``warning`` is
     restored afterwards. Assigning a MagicMock to the attribute would leave
     it installed on the logger, which logging shares by name, for every later
     test.
     """
     logger = get_logger(__name__)
     with mock.patch.object(logger, "warning") as mock_warning:
         logger.warning(self.warning_string)
         mock_warning.assert_called_once_with(self.warning_string)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Any, Dict

import pandas as pd
import plotly.graph_objs as go
from ax.exceptions.core import NoDataError
from ax.modelbridge import ModelBridge
from ax.plot.base import AxPlotConfig, AxPlotTypes
from ax.utils.common.logger import get_logger
from plotly import tools


logger = get_logger("FeatureImportance")  # module-level logger


def plot_feature_importance(df: pd.DataFrame, title: str) -> AxPlotConfig:
    if df.empty:
        raise NoDataError("No Data on Feature Importances found.")
    df.set_index(df.columns[0], inplace=True)
    data = [
        go.Bar(y=df.index, x=df[column_name], name=column_name, orientation="h")
        for column_name in df.columns
    ]
    fig = tools.make_subplots(
        rows=len(df.columns),
        cols=1,
        subplot_titles=df.columns,
        print_grid=False,
        shared_xaxes=True,
示例#17
0
from ax import optimize
from logging import CRITICAL
from ax.utils.common.logger import get_logger
import warnings
# Set the 'ax' logger's level to CRITICAL so lower-severity records logged
# through it are dropped, and ignore UserWarning process-wide.
rt = get_logger(name='ax')
rt.setLevel(CRITICAL)
warnings.filterwarnings("ignore", category=UserWarning)


def test_intro_example():
    """Minimise the Booth function with ``optimize`` (example from https://ax.dev/).

    Both parameters x1 and x2 range over [-10, 10].

    Returns:
        The ``best_values`` element of the tuple returned by ``optimize``.
        NOTE(review): pytest warns when a test function returns non-None.
    """
    def booth(p):
        # Booth function
        return (p["x1"] + 2 * p["x2"] - 7)**2 + (2 * p["x1"] + p["x2"] - 5)**2

    search_space = [
        {"name": axis, "type": "range", "bounds": [-10.0, 10.0]}
        for axis in ("x1", "x2")
    ]
    _, best_values, _, _ = optimize(
        parameters=search_space,
        evaluation_function=booth,
        minimize=True,
    )
    return best_values
示例#18
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from ax.plot.base import AxPlotConfig
from ax.plot.render import _js_requires, _wrap_js, plot_config_to_html
from ax.utils.common.logger import get_logger
from IPython.display import display

logger = get_logger("ipy_plotting")  # module-level logger for notebook plotting helpers


def init_notebook_plotting(offline=False):
    """Inject the Plotly JS requirements into the current notebook cell.

    Args:
        offline: forwarded to ``_js_requires``. Per the original docstring it
            selects online vs. offline mode.
    """
    js_html = _wrap_js(_js_requires(offline=offline))
    display({"text/html": js_html}, raw=True)
    logger.info(
        "Injecting Plotly library into cell. Do not overwrite or delete cell."
    )


def render(plot_config: AxPlotConfig, inject_helpers=False) -> None:
    """Display ``plot_config`` as raw HTML in the notebook.

    Args:
        plot_config: the plot configuration to render.
        inject_helpers: forwarded unchanged to ``plot_config_to_html``.
    """
    html = plot_config_to_html(plot_config, inject_helpers=inject_helpers)
    display({"text/html": html}, raw=True)
示例#19
0
文件: kwargs.py 项目: facebook/Ax
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from inspect import Parameter, signature
from typing import Any, Callable, Dict, Iterable, List, Optional

from ax.utils.common.logger import get_logger
from typeguard import check_type

# Logger named after this module.
logger = get_logger(__name__)

# Type alias for a keyword-argument mapping: argument name -> value.
TKwargs = Dict[str, Any]


def consolidate_kwargs(kwargs_iterable: Iterable[Optional[Dict[str, Any]]],
                       keywords: Iterable[str]) -> Dict[str, Any]:
    """Combine an iterable of kwargs into a single dict of kwargs, where kwargs
    by duplicate keys that appear later in the iterable get priority over the
    ones that appear earlier and only kwargs referenced in keywords will be
    used. This allows to combine somewhat redundant sets of kwargs, where a
    user-set kwarg, for instance, needs to override a default kwarg.

    >>> consolidate_kwargs(
    ...     kwargs_iterable=[{'a': 1, 'b': 2}, {'b': 3, 'c': 4, 'd': 5}],
    ...     keywords=['a', 'b', 'd']
    ... )
    {'a': 1, 'b': 3, 'd': 5}
    """
示例#20
0
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import List, Optional

import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast

logger = get_logger("Winsorize")  # module-level logger, named after the transform


class Winsorize(Transform):
    """Clip the mean values for each metric to lay within the limits provided in
    the config as a tuple of (lower bound percentile, upper bound percentile).
    Config should include those bounds under key "winsorization_limits".
    """
    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
    ) -> None:
        if len(observation_data) == 0:
            raise ValueError(
                "Winsorize transform requires non-empty observation data.")
示例#21
0
 def testLoggerWithFile(self):
     """Check that a logger created with a file path writes the warning to that file.

     The temp file's path is passed as the second argument of ``get_logger``.
     Presumably that attaches a file handler; TODO confirm against get_logger.
     """
     with NamedTemporaryFile() as tf:
         logger = get_logger(__name__, tf.name)
         logger.warning(self.warning_string)
         # Decode rather than str(): str(bytes) gives the repr (b'...'), which
         # escapes quotes, newlines and non-ASCII characters. The with-block
         # closes (and deletes) the file, so no explicit close() is needed.
         self.assertIn(self.warning_string, tf.read().decode("utf-8"))
示例#22
0
from botorch.optim.fit import fit_gpytorch_scipy
from botorch.optim.initializers import initialize_q_batch_nonneg
from botorch.optim.numpy_converter import module_to_array
from botorch.optim.optimize import optimize_acqf
from botorch.optim.utils import _scipy_objective_and_grad
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels.kernel import Kernel
from gpytorch.kernels.rbf_kernel import postprocess_rbf
from gpytorch.kernels.scale_kernel import ScaleKernel
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from scipy.optimize import approx_fprime
from torch import Tensor


logger = get_logger(name="ALEBO")  # module-level logger


class ALEBOKernel(Kernel):
    """The kernel for ALEBO.

    Suppose there exists an ARD RBF GP on an (unknown) linear embedding with
    projection matrix A. We make function evaluations in a different linear
    embedding with projection matrix B (known). This is the appropriate kernel
    for fitting those data.

    This kernel computes a Mahalanobis distance, and the (d x d) PD distance
    matrix Gamma is a parameter that must be fit. This is done by fitting its
    upper Cholesky decomposition, U.

    Args:
示例#23
0
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List, Optional

from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from scipy import stats


logger = get_logger("PercentileY")  # module-level logger, named after the transform


# TODO(jej): Add OptimizationConfig validation - can't transform outcome constraints.
class PercentileY(Transform):
    """Map Y values to percentiles based on their empirical CDF.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        config: Optional[TConfig] = None,
    ) -> None:
        if len(observation_data) == 0:
示例#24
0
文件: helper.py 项目: stevemandala/Ax
import math
from collections import Counter
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter
from ax.core.types import TParameterization
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.transforms.ivw import IVW
from ax.plot.base import DECIMALS, PlotData, PlotInSampleArm, PlotOutOfSampleArm, Z
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none

logger = get_logger(name="PlotHelper")

# Typing alias: a list of records, each mapping a string key to a str or
# float value.
RawData = List[Dict[str, Union[str, float]]]

# Optional mapping from a string key (presumably a display name; TODO
# confirm) to a GeneratorRun.
TNullableGeneratorRunsDict = Optional[Dict[str, GeneratorRun]]


def _format_dict(param_dict: TParameterization,
                 name: str = "Parameterization") -> str:
    """Format a dictionary for labels.

    Args:
        param_dict: Dictionary to be formatted
        name: String name of the thing being formatted.
示例#25
0
import numpy as np
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TConfig, TParamValue
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.standardize_y import compute_standardization_parameters
from ax.utils.common.logger import get_logger

# Import used only for annotations. TYPE_CHECKING is presumably
# typing.TYPE_CHECKING (its import is not visible here), which is False at
# runtime.
if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa F401  # pragma: no cover

# Module-level logger, named after the StratifiedStandardizeY transform.
logger = get_logger("StratifiedStandardizeY")


class StratifiedStandardizeY(Transform):
    """Standardize Y, separately for each metric and for each value of a
    ChoiceParameter.

    The name of the parameter by which to stratify the standardization can be
    specified in config["parameter_name"]. If not specified, will use a task
    parameter if search space contains exactly 1 task parameter, and will raise
    an exception otherwise.

    The stratification parameter must be fixed during generation if there are
    outcome constraints, in order to apply the standardization to the
    constraints.
示例#26
0
import logging
from math import ceil
from typing import cast, Optional, Tuple, Type, Union

from ax.core.experiment import Experiment
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Cont_X_trans, Models, Y_trans
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.winsorize import Winsorize
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none

logger: logging.Logger = get_logger(__name__)

# Default parallelism for the Bayesian-optimization step. Presumably the
# number of trials allowed to run concurrently; TODO confirm at use site.
DEFAULT_BAYESIAN_PARALLELISM = 3
# `BO_MIXED` optimizes all range parameters once for each combination of choice
# parameters, then takes the optimum of those optima. The cost associated with this
# method grows with the number of combinations, and so it is only used when the
# number of enumerated discrete combinations is below some maximum value.
MAX_DISCRETE_ENUMERATIONS_MIXED = 65
# Upper limit on enumerated discrete combinations, presumably for the case with
# no continuous optimization. NOTE(review): this is a float literal (1e4);
# confirm comparisons tolerate that.
MAX_DISCRETE_ENUMERATIONS_NO_CONTINUOUS_OPTIMIZATION = 1e4
# Message template; the `{}` placeholder is presumably filled via str.format
# with a generation strategy name. Confirm at the call site.
SAASBO_INCOMPATIBLE_MESSAGE = (
    "SAASBO is incompatible with {} generation strategy. "
    "Disregarding user input `use_saasbo = True`.")


def _make_sobol_step(
    num_trials: int = -1,
示例#27
0
文件: statstools.py 项目: xiecong/Ax
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple, Union

import numpy as np
import pandas as pd
from ax.utils.common.logger import get_logger

logger = get_logger("Statstools")
# Numeric sequence accepted as either a numpy array or a list of floats.
num_mixed = Union[np.ndarray, List[float]]


def inverse_variance_weight(
        means: np.ndarray,
        variances: np.ndarray,
        conflicting_noiseless: str = "warn") -> Tuple[float, float]:
    """Perform inverse variance weighting.

    Args:
        means: The means of the observations.
        variances: The variances of the observations.
        conflicting_noiseless: How to handle the case of
            multiple observations with zero variance but different means.
            Options are "warn" (default), "ignore" or "raise".

    """
    if conflicting_noiseless not in {"warn", "ignore", "raise"}: