def RExt(dtype: DType,
         rgen=None,
         spec: SearchSpec = None,
         depth: int = 1,
         mode: str = None,
         tracker: OpTracker = None,
         arg_name: str = None,
         identifier: str = None,
         constraint: Callable[[Any], Any] = None,
         **kwargs):
    """Yield existing values for an argument, plus optionally a fresh one.

    Only ``mode='training-data'`` is supported. The candidate pool consists
    of every matching entry of ``spec.inputs``, every matching entry of
    ``spec.intermediates[:depth - 1]`` and, when ``rgen`` is given, one
    "create a new input" option. The pool is shuffled with ``random.shuffle``
    and its entries are yielded in that order.

    Before each yield, ``tracker.record`` is cleared of both labels and then
    exactly one of them is set to describe the value being yielded:

    * ``'ext_<arg_name>_<identifier>'`` -> ``{'source': ..., 'idx': ...}``
      when an existing input/intermediate was chosen;
    * ``'rext_<arg_name>_<identifier>'`` -> ``{'val': ..., 'arg_name': ...}``
      when a new value was drawn from ``rgen``.

    Args:
        dtype: Type filter; a value is a candidate only if
            ``dtype.hasinstance(value)`` is truthy.
        rgen: Optional iterator of fresh values, consumed with ``next``.
            If it is exhausted, the "new input" option is skipped.
        spec: Search spec providing ``inputs`` and ``intermediates``.
        depth: Only the first ``depth - 1`` intermediates are considered.
        mode: Must be ``'training-data'``.
        tracker: Object whose ``record`` dict receives the choice made.
        arg_name: Argument name; used in the tracker labels.
        identifier: Extra disambiguator used in the tracker labels.
        constraint: Additional predicate on candidate values; when omitted
            every value passes.
        **kwargs: Accepted and ignored.

    Yields:
        ``Fetcher`` objects for existing values, and ``NewInp(value)`` for
        values drawn from ``rgen``.

    Raises:
        AutoPandasException: If ``mode`` is not ``'training-data'``.
    """
    if constraint is None:

        def constraint(x):
            return True

    if mode != 'training-data':
        raise AutoPandasException("Unrecognized mode {} in RExt".format(mode))

    def accepts(val) -> bool:
        return dtype.hasinstance(val) and constraint(val)

    # ``None`` is a placeholder meaning "create a brand-new input from rgen".
    # Build the pool in the same order as before shuffling so seeded runs
    # stay reproducible.
    pool: List[Optional[Fetcher]] = [
        Fetcher(val=val, source='inps', idx=idx)
        for idx, val in enumerate(spec.inputs) if accepts(val)
    ]
    pool.extend(
        Fetcher(val=val, source='intermediates', idx=idx)
        for idx, val in enumerate(spec.intermediates[:depth - 1])
        if accepts(val))

    if rgen is not None:
        pool.append(None)

    random.shuffle(pool)
    label = 'ext_' + arg_name + '_' + identifier
    rlabel = 'rext_' + arg_name + '_' + identifier
    for selection in pool:
        # Drop the previous choice so the record only describes the value
        # that is about to be yielded.
        tracker.record.pop(label, None)
        tracker.record.pop(rlabel, None)
        if selection is None:
            # We've decided to create a new input altogether.
            try:
                val = next(rgen)
            except StopIteration:
                # rgen is exhausted. Letting StopIteration escape a generator
                # turns it into RuntimeError (PEP 479); skip this option.
                continue
            tracker.record[rlabel] = {'val': val, 'arg_name': arg_name}
            yield NewInp(val)

        else:
            tracker.record[label] = {
                'source': selection.source,
                'idx': selection.idx
            }
            yield selection
Ejemplo n.º 2
0
def Ext(dtype: DType,
        spec: SearchSpec = None,
        depth: int = 1,
        mode: str = None,
        tracker: OpTracker = None,
        arg_name: str = None,
        identifier: str = None,
        constraint: Callable[[Any], Any] = None,
        **kwargs):
    """Yield existing values (program inputs or intermediates) for an argument.

    Candidates are drawn from ``spec.inputs`` and from the first
    ``depth - 1`` entries of ``spec.intermediates``. A value is a candidate
    only if ``dtype.hasinstance(value)`` and ``constraint(value)`` are truthy.
    Behaviour depends on ``mode``:

    * ``'exhaustive'`` / ``'inference'``: yield every matching intermediate,
      most recent first, then every matching input.
    * ``'arguments-training-data'``: replay the choice stored under
      ``tracker.record['ext_<arg_name>_<identifier>']`` (its ``'source'`` and
      ``'idx'``). No type/constraint filtering is applied. An unrecognized
      source yields nothing.
    * ``'arguments-training-data-best-effort'``: yield the first matching
      input, else the first matching intermediate, for which
      ``Checker.check(value, spec.args[arg_name])`` passes.
    * any other mode: yield nothing.

    Args:
        dtype: Type filter for candidate values.
        spec: Search spec with ``inputs`` and ``intermediates``. In
            best-effort mode it is also expected to expose ``args``.
        depth: Only the first ``depth - 1`` intermediates are considered.
        mode: Selects the behaviour described above.
        tracker: Object whose ``record`` dict is read in
            ``'arguments-training-data'`` mode.
        arg_name: Argument name; used in the tracker label and ``spec.args``.
        identifier: Extra disambiguator used in the tracker label.
        constraint: Additional predicate on candidates; when omitted every
            value passes.
        **kwargs: Accepted and ignored.

    Yields:
        ``Fetcher`` objects whose ``idx`` is the value's position in
        ``spec.inputs`` or ``spec.intermediates``.

    Raises:
        AutoPandasInversionFailedException: In ``'arguments-training-data'``
            mode when the label is missing from the tracker, or in best-effort
            mode when no candidate passes ``Checker.check``.
    """
    if constraint is None:

        def constraint(x):
            return True

    def accepts(val) -> bool:
        return dtype.hasinstance(val) and constraint(val)

    if mode == 'exhaustive' or mode == 'inference':
        intermediates = spec.intermediates[:depth - 1]
        # Walk intermediates newest-first. The index is taken from the slice
        # itself, not from ``depth``, so it stays correct even when fewer
        # than ``depth - 1`` intermediates exist.
        for idx in range(len(intermediates) - 1, -1, -1):
            val = intermediates[idx]
            if accepts(val):
                yield Fetcher(val=val, source='intermediates', idx=idx)

        for idx, val in enumerate(spec.inputs):
            if accepts(val):
                yield Fetcher(val=val, source='inps', idx=idx)

    elif mode == 'arguments-training-data':
        label = 'ext_' + arg_name + '_' + identifier
        if label not in tracker.record:
            raise AutoPandasInversionFailedException(
                "Could not find label {} in tracker".format(label))

        record = tracker.record[label]
        idx = record['idx']
        if record['source'] == 'inps':
            yield Fetcher(val=spec.inputs[idx], source='inps', idx=idx)

        elif record['source'] == 'intermediates':
            yield Fetcher(val=spec.intermediates[idx],
                          source='intermediates',
                          idx=idx)

    elif mode == 'arguments-training-data-best-effort':
        training_spec: ArgTrainingSpec = spec
        label = 'ext_' + arg_name + '_' + identifier
        # Inputs take priority over intermediates.
        sources = (('inps', spec.inputs),
                   ('intermediates', spec.intermediates[:depth - 1]))
        for source, values in sources:
            for idx, val in enumerate(values):
                if not accepts(val):
                    continue

                # ``training_spec.args[arg_name]`` is looked up lazily so a
                # missing key only matters once a candidate passes the filter.
                if Checker.check(val, training_spec.args[arg_name]):
                    yield Fetcher(val=val, source=source, idx=idx)
                    return

        raise AutoPandasInversionFailedException(
            "Could not invert generator for {} at {}".format(arg_name, label))