예제 #1
0
파일: mixed_input.py 프로젝트: zzmcdc/ray
    def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext):
        """Initialize a MixedInput.

        Args:
            dist (dict): dict mapping JsonReader paths or the literal key
                "sampler" to probabilities. The probabilities must sum to
                1.0 (within floating-point tolerance).
            ioctx (IOContext): current IO context object.

        Raises:
            ValueError: if the probabilities do not sum to 1.0.
        """
        # Imported locally so the method does not rely on a module-level
        # import being present.
        import math

        # Compare with a tolerance: exact float equality rejects valid
        # distributions, e.g. ten entries of 0.1 sum to 0.9999999999999999.
        if not math.isclose(sum(dist.values()), 1.0):
            raise ValueError("Values must sum to 1.0: {}".format(dist))
        # Parallel lists: p[i] is the probability given for choices[i].
        self.choices = []
        self.p = []
        for k, v in dist.items():
            if k == "sampler":
                # The special key "sampler" maps to the context's default
                # sampler input rather than to a file reader.
                self.choices.append(ioctx.default_sampler_input())
            else:
                # Any other key is handed to JsonReader as its argument.
                self.choices.append(JsonReader(k))
            self.p.append(v)
예제 #2
0
    def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext):
        """Initialize a MixedInput.

        Args:
            dist (dict): dict mapping input sources to probabilities. A key
                may be the literal string "sampler", a plain function that
                takes the IO context, the name of a registered input, or
                otherwise a value passed to JsonReader (e.g. a path). The
                probabilities must sum to 1.0 (within floating-point
                tolerance).
            ioctx (IOContext): current IO context object.

        Raises:
            ValueError: if the probabilities do not sum to 1.0.
        """
        # Imported locally so the method does not rely on a module-level
        # import being present.
        import math

        # Compare with a tolerance: exact float equality rejects valid
        # distributions, e.g. ten entries of 0.1 sum to 0.9999999999999999.
        if not math.isclose(sum(dist.values()), 1.0):
            raise ValueError("Values must sum to 1.0: {}".format(dist))
        # Parallel lists: p[i] is the probability given for choices[i].
        self.choices = []
        self.p = []
        for k, v in dist.items():
            if k == "sampler":
                # The special key "sampler" maps to the context's default
                # sampler input.
                self.choices.append(ioctx.default_sampler_input())
            elif isinstance(k, FunctionType):
                # A function key is called with the IO context and its
                # return value is used as the input.
                self.choices.append(k(ioctx))
            elif isinstance(k, str) and registry_contains_input(k):
                # A string found in the input registry is resolved to its
                # creator, which is then called with the IO context.
                input_creator = registry_get_input(k)
                self.choices.append(input_creator(ioctx))
            else:
                # Anything else is handed to JsonReader together with the
                # IO context.
                self.choices.append(JsonReader(k, ioctx))
            self.p.append(v)