Exemplo n.º 1
0
 def checkpoint():
     """Debug hook: re-validate the memo tables for the current `node`.

     Does nothing unless the enclosing scope's `checkpoint_asserts` is true.
     When enabled it runs `self.assert_integrity_idxs_take()` and then
     toposorts the memoized index union and every memoized take for `node`
     (toposort is used elsewhere in this code to catch graph cycles).
     """
     if not checkpoint_asserts:
         return
     self.assert_integrity_idxs_take()
     if node in self.idxs_memo:
         toposort(self.idxs_memo[node])
     if node in self.take_memo:
         for pending_take in self.take_memo[node]:
             toposort(pending_take)
Exemplo n.º 2
0
 def checkpoint():
     """Debug hook: re-validate the memo tables for the current `node`.

     Does nothing unless the enclosing scope's `checkpoint_asserts` is true.
     When enabled it runs `self.assert_integrity_idxs_take()` and then
     toposorts the memoized index union and every memoized take for `node`
     (toposort is used elsewhere in this code to catch graph cycles).
     """
     if not checkpoint_asserts:
         return
     self.assert_integrity_idxs_take()
     if node in self.idxs_memo:
         toposort(self.idxs_memo[node])
     if node in self.take_memo:
         for pending_take in self.take_memo[node]:
             toposort(pending_take)
Exemplo n.º 3
0
    def __init__(self, bandit, seed=seed, cmd=None, workdir=None):
        """Vectorize `bandit.expr` and bind a seeded RNG into the result.

        Parameters
        ----------
        bandit : object
            Must expose `.expr` (a pyll graph) and `.params` (a mapping whose
            keys are expected to equal the labels found by vectorization).
        seed : int
            Seed used to build `self.rng`, a `np.random.RandomState`.
        cmd, workdir :
            Stored unchanged as `self.cmd` / `self.workdir`.

        Raises
        ------
        AssertionError
            If vectorization altered `bandit.expr` in place, or if the labels
            of the vectorized idxs/vals disagree with `bandit.params`.
        Whatever `pyll.toposort` raises for a graph that contains a cycle.
        """
        self.bandit = bandit
        self.seed = seed
        self.rng = np.random.RandomState(self.seed)
        self.cmd = cmd
        self.workdir = workdir
        # -- placeholder node; stands for a list of new ids at eval time
        self.s_new_ids = pyll.Literal('new_ids')

        # -- snapshot the graph so in-place damage by VectorizeHelper shows up
        nodes_before = pyll.dfs(self.bandit.expr)
        pyll.toposort(self.bandit.expr)  # raises if expr has cycles

        helper = VectorizeHelper(self.bandit.expr, self.s_new_ids)
        self.vh = helper
        pyll.toposort(helper.v_expr)  # raises if the vectorized expr has cycles

        label_idxs = helper.idxs_by_label()
        label_vals = helper.vals_by_label()
        nodes_after = pyll.dfs(self.bandit.expr)
        assert nodes_before == nodes_after
        assert set(label_idxs.keys()) == set(label_vals.keys())
        assert set(label_idxs.keys()) == set(self.bandit.params.keys())

        # -- make the graph runnable and SON-encodable.
        #    N.B. recursive_set_rng_kwarg rewrites the graph in place.
        joint_idxs_vals = scope.pos_args(label_idxs, label_vals)
        rng_node = pyll.as_apply(self.rng)
        self.s_idxs_vals = recursive_set_rng_kwarg(joint_idxs_vals, rng_node)

        # -- fail now if the final graph has no topological ordering
        pyll.toposort(self.s_idxs_vals)
    def __init__(self, bandit, seed=seed, cmd=None, workdir=None):
        """Vectorize `bandit.expr` and bind a seeded RNG into the result.

        Parameters
        ----------
        bandit : object
            Must expose `.expr` (a pyll graph) and `.params` (a mapping whose
            keys are expected to equal the labels found by vectorization).
        seed : int
            Seed used to build `self.rng`, a `np.random.RandomState`.
        cmd, workdir :
            Stored unchanged as `self.cmd` / `self.workdir`.

        Raises
        ------
        AssertionError
            If vectorization altered `bandit.expr` in place, or if the labels
            of the vectorized idxs/vals disagree with `bandit.params`.
        Whatever `pyll.toposort` raises for a graph that contains a cycle.
        """
        self.bandit = bandit
        self.seed = seed
        self.rng = np.random.RandomState(self.seed)
        self.cmd = cmd
        self.workdir = workdir
        # -- placeholder node; stands for a list of new ids at eval time
        self.s_new_ids = pyll.Literal('new_ids')

        # -- snapshot the graph so in-place damage by VectorizeHelper shows up
        nodes_before = pyll.dfs(self.bandit.expr)
        pyll.toposort(self.bandit.expr)  # raises if expr has cycles

        helper = VectorizeHelper(self.bandit.expr, self.s_new_ids)
        self.vh = helper
        pyll.toposort(helper.v_expr)  # raises if the vectorized expr has cycles

        label_idxs = helper.idxs_by_label()
        label_vals = helper.vals_by_label()
        nodes_after = pyll.dfs(self.bandit.expr)
        assert nodes_before == nodes_after
        assert set(label_idxs.keys()) == set(label_vals.keys())
        assert set(label_idxs.keys()) == set(self.bandit.params.keys())

        # -- make the graph runnable and SON-encodable.
        #    N.B. recursive_set_rng_kwarg rewrites the graph in place.
        joint_idxs_vals = scope.pos_args(label_idxs, label_vals)
        rng_node = pyll.as_apply(self.rng)
        self.s_idxs_vals = recursive_set_rng_kwarg(joint_idxs_vals, rng_node)

        # -- fail now if the final graph has no topological ordering
        pyll.toposort(self.s_idxs_vals)
Exemplo n.º 5
0
def space_eval(space, hp_assignment):
    """Compute a point in a search space from a hyperparameter assignment.

    Parameters
    ----------
    space : pyll graph
        A graph involving hp nodes (see `pyll_utils`).
    hp_assignment : dict
        Maps hp node labels to values.

    Returns
    -------
    The result of `pyll.rec_eval(space, memo=memo)`, where `memo` maps each
    `hyperopt_param` node whose label appears in `hp_assignment` to the
    assigned value. Params whose labels are absent are not pre-seeded.
    """
    memo = {}
    for node in pyll.toposort(space):
        if node.name != 'hyperopt_param':
            continue
        label = node.arg['label'].eval()
        if label in hp_assignment:
            memo[node] = hp_assignment[label]
    return pyll.rec_eval(space, memo=memo)
Exemplo n.º 6
0
def space_eval(space, hp_assignment):
    """Compute a point in a search space from a hyperparameter assignment.

    Parameters
    ----------
    space : pyll graph
        A graph involving hp nodes (see `pyll_utils`).
    hp_assignment : dict
        Maps hp node labels to values.

    Returns
    -------
    The result of `pyll.rec_eval(space, memo=memo)`, where `memo` maps each
    `hyperopt_param` node whose label appears in `hp_assignment` to the
    assigned value. Params whose labels are absent are not pre-seeded.
    """
    memo = {}
    for node in pyll.toposort(space):
        if node.name != 'hyperopt_param':
            continue
        label = node.arg['label'].eval()
        if label in hp_assignment:
            memo[node] = hp_assignment[label]
    return pyll.rec_eval(space, memo=memo)
Exemplo n.º 7
0
    def __init__(self,
                 fn,
                 expr,
                 args=None,
                 workdir=None,
                 pass_expr_memo_ctrl=None,
                 **bandit_kwargs):
        """Wrap `fn` and the search-space graph `expr` for use by fmin.

        Parameters
        ----------
        fn : callable
            Stored as `self.fn`.
        expr : pyll graph
            Stored as `self.expr`, passed to `base.Bandit.__init__`, and
            vectorized below.
        args : list (or None)
            Stored as `self.args`. When omitted, each instance gets its own
            new empty list.
        workdir : string (or None)
            Stored as `self.workdir`.
        pass_expr_memo_ctrl : bool (or None)
            If None, read from `fn.fmin_pass_expr_memo_ctrl`, defaulting to
            False when `fn` has no such attribute.
        **bandit_kwargs
            Forwarded to `base.Bandit.__init__` (called with do_checks=False).

        Raises
        ------
        AssertionError
            If VectorizeHelper altered `expr` in place, or the labels of the
            vectorized idxs/vals disagree with `self.params`.
        Whatever `pyll.toposort` raises when a graph contains a cycle.
        """
        self.cmd = ('domain_attachment', 'FMinIter_Domain')
        self.fn = fn
        self.expr = expr
        # -- a literal `[]` default would be one list shared by every instance
        #    created without `args`; build a fresh one per instance instead.
        self.args = [] if args is None else args
        if pass_expr_memo_ctrl is None:
            self.pass_expr_memo_ctrl = getattr(fn, 'fmin_pass_expr_memo_ctrl',
                                               False)
        else:
            self.pass_expr_memo_ctrl = pass_expr_memo_ctrl
        base.Bandit.__init__(self, expr, do_checks=False, **bandit_kwargs)

        # -- This code was stolen from base.BanditAlgo, a class which may soon
        #    be gone
        self.workdir = workdir
        self.s_new_ids = pyll.Literal('new_ids')  # -- list at eval-time
        before = pyll.dfs(self.expr)
        # -- raises exception if expr contains cycles
        pyll.toposort(self.expr)
        vh = self.vh = VectorizeHelper(self.expr, self.s_new_ids)
        # -- raises exception if v_expr contains cycles
        pyll.toposort(vh.v_expr)

        idxs_by_label = vh.idxs_by_label()
        vals_by_label = vh.vals_by_label()
        after = pyll.dfs(self.expr)
        # -- try to detect if VectorizeHelper screwed up anything inplace
        assert before == after
        assert set(idxs_by_label.keys()) == set(vals_by_label.keys())
        # -- self.params is not assigned in this method; presumably it is
        #    populated by base.Bandit.__init__ above (TODO confirm)
        assert set(idxs_by_label.keys()) == set(self.params.keys())

        # -- make the graph runnable and SON-encodable
        # N.B. operates inplace
        # -- self.rng likewise presumably comes from base.Bandit.__init__
        self.s_idxs_vals = recursive_set_rng_kwarg(
            pyll.scope.pos_args(idxs_by_label, vals_by_label),
            pyll.as_apply(self.rng))

        # -- raises an exception if no topological ordering exists
        pyll.toposort(self.s_idxs_vals)
Exemplo n.º 8
0
    def __init__(self,
                 fn,
                 expr,
                 args=None,
                 workdir=None,
                 pass_expr_memo_ctrl=None,
                 **bandit_kwargs):
        """Wrap `fn` and the search-space graph `expr` for use by fmin.

        Parameters
        ----------
        fn : callable
            Stored as `self.fn`.
        expr : pyll graph
            Stored as `self.expr`, passed to `base.Bandit.__init__`, and
            vectorized below.
        args : list (or None)
            Stored as `self.args`. When omitted, each instance gets its own
            new empty list.
        workdir : string (or None)
            Stored as `self.workdir`.
        pass_expr_memo_ctrl : bool (or None)
            If None, read from `fn.fmin_pass_expr_memo_ctrl`, defaulting to
            False when `fn` has no such attribute.
        **bandit_kwargs
            Forwarded to `base.Bandit.__init__` (called with do_checks=False).

        Raises
        ------
        AssertionError
            If VectorizeHelper altered `expr` in place, or the labels of the
            vectorized idxs/vals disagree with `self.params`.
        Whatever `pyll.toposort` raises when a graph contains a cycle.
        """
        self.cmd = ('domain_attachment', 'FMinIter_Domain')
        self.fn = fn
        self.expr = expr
        # -- a literal `[]` default would be one list shared by every instance
        #    created without `args`; build a fresh one per instance instead.
        self.args = [] if args is None else args
        if pass_expr_memo_ctrl is None:
            self.pass_expr_memo_ctrl = getattr(fn, 'fmin_pass_expr_memo_ctrl',
                                               False)
        else:
            self.pass_expr_memo_ctrl = pass_expr_memo_ctrl
        base.Bandit.__init__(self, expr, do_checks=False, **bandit_kwargs)

        # -- This code was stolen from base.BanditAlgo, a class which may soon
        #    be gone
        self.workdir = workdir
        self.s_new_ids = pyll.Literal('new_ids')  # -- list at eval-time
        before = pyll.dfs(self.expr)
        # -- raises exception if expr contains cycles
        pyll.toposort(self.expr)
        vh = self.vh = VectorizeHelper(self.expr, self.s_new_ids)
        # -- raises exception if v_expr contains cycles
        pyll.toposort(vh.v_expr)

        idxs_by_label = vh.idxs_by_label()
        vals_by_label = vh.vals_by_label()
        after = pyll.dfs(self.expr)
        # -- try to detect if VectorizeHelper screwed up anything inplace
        assert before == after
        assert set(idxs_by_label.keys()) == set(vals_by_label.keys())
        # -- self.params is not assigned in this method; presumably it is
        #    populated by base.Bandit.__init__ above (TODO confirm)
        assert set(idxs_by_label.keys()) == set(self.params.keys())

        # -- make the graph runnable and SON-encodable
        # N.B. operates inplace
        # -- self.rng likewise presumably comes from base.Bandit.__init__
        self.s_idxs_vals = recursive_set_rng_kwarg(
            pyll.scope.pos_args(idxs_by_label, vals_by_label),
            pyll.as_apply(self.rng))

        # -- raises an exception if no topological ordering exists
        pyll.toposort(self.s_idxs_vals)
Exemplo n.º 9
0
    def __init__(self, fn, expr,
                 workdir=None,
                 pass_expr_memo_ctrl=None,
                 name=None,
                 loss_target=None,
                 ):
        """
        Parameters
        ----------

        fn : callable
            The `fn` argument to `fmin` (see `hyperopt.fmin.fmin`).

        expr : hyperopt.pyll.Apply
            The `space` argument to `fmin` (see `hyperopt.fmin.fmin`).
            It is passed through `pyll.as_apply` before being stored.

        workdir : string (or None)
            If non-None, the current working directory will be `workdir` while
            `expr` and `fn` are evaluated. (XXX Currently only respected by
            jobs run via MongoWorker)

        pass_expr_memo_ctrl : bool (or None)
            If True, `fn` will be called like this:
            `fn(self.expr, memo, ctrl)`,
            where `memo` is a dictionary mapping `Apply` nodes to their
            computed values, and `ctrl` is a `Ctrl` instance for communicating
            with a Trials database.  This lower-level calling convention is
            useful if you want to call e.g. `hyperopt.pyll.rec_eval` yourself
            in some customized way.
            If None, the value of `fn.fmin_pass_expr_memo_ctrl` is used when
            that attribute exists, and False otherwise.

        name : string (or None)
            Label, used for pretty-printing.

        loss_target : float (or None)
            The actual or estimated minimum of `fn`.
            Some optimization algorithms may behave differently if their first
            objective is to find an input that achieves a certain value,
            rather than the more open-ended objective of pure minimization.
            XXX: Move this from Domain to be an fmin arg.

        Raises
        ------
        DuplicateLabel
            If two `hyperopt_param` nodes in `expr` share the same label.
        AssertionError
            If VectorizeHelper altered `self.expr` in place, or the labels of
            the vectorized idxs/vals disagree with `self.params`.
        """
        self.fn = fn
        if pass_expr_memo_ctrl is None:
            pass_expr_memo_ctrl = getattr(fn, 'fmin_pass_expr_memo_ctrl',
                                          False)
        self.pass_expr_memo_ctrl = pass_expr_memo_ctrl

        self.expr = pyll.as_apply(expr)

        # -- label -> the `obj` argument of each hyperparameter node
        self.params = {}
        for apply_node in pyll.dfs(self.expr):
            if apply_node.name != 'hyperopt_param':
                continue
            hp_label = apply_node.arg['label'].obj
            if hp_label in self.params:
                raise DuplicateLabel(hp_label)
            self.params[hp_label] = apply_node.arg['obj']

        self.loss_target = loss_target
        self.name = name

        self.workdir = workdir
        # -- placeholder node; stands for a list of new ids at eval time
        self.s_new_ids = pyll.Literal('new_ids')

        # -- snapshot the graph so in-place damage by VectorizeHelper shows up
        nodes_before = pyll.dfs(self.expr)
        pyll.toposort(self.expr)  # raises if expr has cycles

        helper = VectorizeHelper(self.expr, self.s_new_ids)
        self.vh = helper
        pyll.toposort(helper.v_expr)  # raises if the vectorized expr has cycles

        label_idxs = helper.idxs_by_label()
        label_vals = helper.vals_by_label()
        nodes_after = pyll.dfs(self.expr)
        assert nodes_before == nodes_after
        assert set(label_idxs.keys()) == set(label_vals.keys())
        assert set(label_idxs.keys()) == set(self.params.keys())

        self.s_rng = pyll.Literal('rng-placeholder')
        joint_idxs_vals = pyll.scope.pos_args(label_idxs, label_vals)
        # -- N.B. recursive_set_rng_kwarg rewrites the graph in place
        self.s_idxs_vals = recursive_set_rng_kwarg(joint_idxs_vals,
                                                   self.s_rng)

        # -- fail now if the final graph has no topological ordering
        pyll.toposort(self.s_idxs_vals)

        # -- Serialization protocol: self.cmd tells e.g. MongoWorker how this
        #    domain should be [un]serialized.
        #    XXX This mechanism deserves review as support for ipython
        #        workers improves.
        self.cmd = ('domain_attachment', 'FMinIter_Domain')
Exemplo n.º 10
0
    def __init__(
        self,
        fn,
        expr,
        workdir=None,
        pass_expr_memo_ctrl=None,
        name=None,
        loss_target=None,
    ):
        """
        Parameters
        ----------

        fn : callable
            The `fn` argument to `fmin` (see `hyperopt.fmin.fmin`).

        expr : hyperopt.pyll.Apply
            The `space` argument to `fmin` (see `hyperopt.fmin.fmin`).
            It is passed through `pyll.as_apply` before being stored.

        workdir : string (or None)
            If non-None, the current working directory will be `workdir` while
            `expr` and `fn` are evaluated. (XXX Currently only respected by
            jobs run via MongoWorker)

        pass_expr_memo_ctrl : bool (or None)
            If True, `fn` will be called like this:
            `fn(self.expr, memo, ctrl)`,
            where `memo` is a dictionary mapping `Apply` nodes to their
            computed values, and `ctrl` is a `Ctrl` instance for communicating
            with a Trials database.  This lower-level calling convention is
            useful if you want to call e.g. `hyperopt.pyll.rec_eval` yourself
            in some customized way.
            If None, the value of `fn.fmin_pass_expr_memo_ctrl` is used when
            that attribute exists, and False otherwise.

        name : string (or None)
            Label, used for pretty-printing.

        loss_target : float (or None)
            The actual or estimated minimum of `fn`.
            Some optimization algorithms may behave differently if their first
            objective is to find an input that achieves a certain value,
            rather than the more open-ended objective of pure minimization.
            XXX: Move this from Domain to be an fmin arg.

        Raises
        ------
        DuplicateLabel
            If two `hyperopt_param` nodes in `expr` share the same label.
        AssertionError
            If VectorizeHelper altered `self.expr` in place, or the labels of
            the vectorized idxs/vals disagree with `self.params`.
        """
        self.fn = fn
        if pass_expr_memo_ctrl is None:
            pass_expr_memo_ctrl = getattr(fn, 'fmin_pass_expr_memo_ctrl',
                                          False)
        self.pass_expr_memo_ctrl = pass_expr_memo_ctrl

        self.expr = pyll.as_apply(expr)

        # -- label -> the `obj` argument of each hyperparameter node
        self.params = {}
        for apply_node in pyll.dfs(self.expr):
            if apply_node.name != 'hyperopt_param':
                continue
            hp_label = apply_node.arg['label'].obj
            if hp_label in self.params:
                raise DuplicateLabel(hp_label)
            self.params[hp_label] = apply_node.arg['obj']

        self.loss_target = loss_target
        self.name = name

        self.workdir = workdir
        # -- placeholder node; stands for a list of new ids at eval time
        self.s_new_ids = pyll.Literal('new_ids')

        # -- snapshot the graph so in-place damage by VectorizeHelper shows up
        nodes_before = pyll.dfs(self.expr)
        pyll.toposort(self.expr)  # raises if expr has cycles

        helper = VectorizeHelper(self.expr, self.s_new_ids)
        self.vh = helper
        pyll.toposort(helper.v_expr)  # raises if the vectorized expr has cycles

        label_idxs = helper.idxs_by_label()
        label_vals = helper.vals_by_label()
        nodes_after = pyll.dfs(self.expr)
        assert nodes_before == nodes_after
        assert set(label_idxs.keys()) == set(label_vals.keys())
        assert set(label_idxs.keys()) == set(self.params.keys())

        self.s_rng = pyll.Literal('rng-placeholder')
        joint_idxs_vals = pyll.scope.pos_args(label_idxs, label_vals)
        # -- N.B. recursive_set_rng_kwarg rewrites the graph in place
        self.s_idxs_vals = recursive_set_rng_kwarg(joint_idxs_vals,
                                                   self.s_rng)

        # -- fail now if the final graph has no topological ordering
        pyll.toposort(self.s_idxs_vals)

        # -- Serialization protocol: self.cmd tells e.g. MongoWorker how this
        #    domain should be [un]serialized.
        #    XXX This mechanism deserves review as support for ipython
        #        workers improves.
        self.cmd = ('domain_attachment', 'FMinIter_Domain')
Exemplo n.º 11
0
    def build_idxs_vals(self, node, wanted_idxs):
        """
        This recursive procedure should be called on an output-node.

        Build (or extend) the vectorized sub-graph that computes `node` for
        the trial indices produced by `wanted_idxs`, and return the graph node
        holding those values.

        Parameters
        ----------
        node : pyll Apply
            A node of the original (non-vectorized) expression graph.
        wanted_idxs : pyll Apply
            Graph node producing the indices the caller needs. It must not be
            the union already memoized for `node` (asserted below).

        Returns
        -------
        pyll Apply
            Usually a `scope.idxs_take(all_idxs, all_vals, wanted_idxs)` node.
            `hyperopt_param` and `hyperopt_result` nodes are transparent: the
            result for their `obj` argument is returned instead. In the
            general case, an existing take for exactly `wanted_idxs` is reused.

        Side effects
        ------------
        `self.idxs_memo[node]` holds an `array_union` node of the indices
        computed for `node` (for literals, the union of `self.expr_idxs`);
        `self.take_memo[node]` lists every `idxs_take` built for `node`.
        Both are mutated in place: unions grow via `pos_args.append`, and in
        the switch/general cases earlier takes are re-pointed at the newly
        built `all_vals`.

        Raises
        ------
        NotImplementedError
            If the memo says `wanted_idxs` was requested for a general node
            but no take with those indices is recorded.
        """
        # -- debugging aid: set True so every checkpoint() call re-checks memo
        #    integrity and toposorts this node's memoized graphs.
        checkpoint_asserts = False

        def checkpoint():
            # -- no-op unless checkpoint_asserts is True
            if checkpoint_asserts:
                self.assert_integrity_idxs_take()
                if node in self.idxs_memo:
                    toposort(self.idxs_memo[node])
                if node in self.take_memo:
                    for take in self.take_memo[node]:
                        toposort(take)

        checkpoint()

        # wanted_idxs are fixed, whereas idxs_memo
        # is full of unions, that can grow in subsequent recursive
        # calls to build_idxs_vals with node as argument.
        assert wanted_idxs != self.idxs_memo.get(node)

        # -- easy exit case
        if node.name == 'hyperopt_param':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- easy exit case
        elif node.name == 'hyperopt_result':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- literal case: always take from universal set
        elif node.name == 'literal':
            if node in self.idxs_memo:
                # -- reuse the (all_idxs, all_vals) of the first recorded take
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- initialize idxs_memo to full set
                all_idxs = self.expr_idxs
                n_times = scope.len(all_idxs)
                # -- put array_union into graph for consistency, though it is
                # not necessary
                all_idxs = scope.array_union(all_idxs)
                self.idxs_memo[node] = all_idxs
                # -- the literal repeated len(self.expr_idxs) times
                all_vals = scope.asarray(scope.repeat(n_times, node))
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                assert node not in self.take_memo
                self.take_memo[node] = [wanted_vals]
                checkpoint()
            return wanted_vals

        # -- switch case: complicated
        elif node.name == 'switch':
            if (node in self.idxs_memo
                and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- these indices are already part of the union; take them
                #    from the existing (all_idxs, all_vals) pair
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- we need to add some indexes
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]
                    assert all_idxs.name == 'array_union'
                    all_idxs.pos_args.append(wanted_idxs)
                else:
                    all_idxs = scope.array_union(wanted_idxs)

                # -- pos_args[0] selects the branch; vectorize it over all_idxs
                choice = node.pos_args[0]
                all_choices = self.build_idxs_vals(choice, all_idxs)

                # -- split all_idxs into one index set per option, vectorize
                #    each option over its own subset, then merge the results
                options = node.pos_args[1:]
                args_idxs = scope.vchoice_split(all_idxs, all_choices,
                                                len(options))
                all_vals = scope.vchoice_merge(all_idxs, all_choices)
                for opt_ii, idxs_ii in zip(options, args_idxs):
                    all_vals.pos_args.append(
                        as_apply([
                            idxs_ii,
                            self.build_idxs_vals(opt_ii, idxs_ii),
                        ]))

                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    # NOTE(review): when `node` was already memoized,
                    # wanted_idxs was appended to this same union at the top
                    # of this branch, so it now appears twice. Presumably
                    # harmless for a union, but looks redundant -- confirm
                    # before removing either append.
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    # -- re-point earlier takes at the rebuilt values
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        # -- general case
        else:
            # -- this is a general node.
            #    It is generally handled with idxs_memo,
            #    but vectorize_stochastic may immediately transform it into
            #    a more compact form.
            if (node in self.idxs_memo
                and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- return the take already built for exactly these indices
                for take in self.take_memo[node]:
                    if take.pos_args[2] == wanted_idxs:
                        return take
                raise NotImplementedError('how did this happen?')
                #all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                #wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                #self.take_memo[node].append(wanted_vals)
                #checkpoint()
            else:
                # XXX
                # -- determine if wanted_idxs is actually a subset of the idxs
                # that we are already computing.  This is not only an
                # optimization, but prevents the creation of cycles, which
                # would otherwise occur if we have a graph of the form
                # switch(f(a), g(a), 0). If there are other switches inside f
                # and g, does this get trickier?

                # -- assume we need to add some indexes
                checkpoint()
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]

                else:
                    all_idxs = scope.array_union(wanted_idxs)
                checkpoint()

                # -- apply node's op elementwise over all_idxs: each positional
                #    and keyword argument is itself vectorized over all_idxs
                all_vals = scope.idxs_map(all_idxs, node.name)
                for ii, aa in enumerate(node.pos_args):
                    all_vals.pos_args.append(as_apply([
                        all_idxs, self.build_idxs_vals(aa, all_idxs)]))
                    checkpoint()
                for ii, (nn, aa) in enumerate(node.named_args):
                    all_vals.named_args.append([nn, as_apply([
                        all_idxs, self.build_idxs_vals(aa, all_idxs)])])
                    checkpoint()
                # -- may rewrite the idxs_map into a more compact form
                all_vals = vectorize_stochastic(all_vals)

                checkpoint()
                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    toposort(self.idxs_memo[node])
                    # -- this catches the cycle bug mentioned above
                    # -- re-point earlier takes at the rebuilt values
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        return wanted_vals
Exemplo n.º 12
0
    def build_idxs_vals(self, node, wanted_idxs):
        """
        This recursive procedure should be called on an output-node.

        Build (or extend) the vectorized sub-graph that computes `node` for
        the trial indices produced by `wanted_idxs`, and return the graph node
        holding those values.

        Parameters
        ----------
        node : pyll Apply
            A node of the original (non-vectorized) expression graph.
        wanted_idxs : pyll Apply
            Graph node producing the indices the caller needs. It must not be
            the union already memoized for `node` (asserted below).

        Returns
        -------
        pyll Apply
            Usually a `scope.idxs_take(all_idxs, all_vals, wanted_idxs)` node.
            `hyperopt_param` and `hyperopt_result` nodes are transparent: the
            result for their `obj` argument is returned instead. In the
            general case, an existing take for exactly `wanted_idxs` is reused.

        Side effects
        ------------
        `self.idxs_memo[node]` holds an `array_union` node of the indices
        computed for `node` (for literals, the union of `self.expr_idxs`);
        `self.take_memo[node]` lists every `idxs_take` built for `node`.
        Both are mutated in place: unions grow via `pos_args.append`, and in
        the switch/general cases earlier takes are re-pointed at the newly
        built `all_vals`.

        Raises
        ------
        NotImplementedError
            If the memo says `wanted_idxs` was requested for a general node
            but no take with those indices is recorded.
        """
        # -- debugging aid: set True so every checkpoint() call re-checks memo
        #    integrity and toposorts this node's memoized graphs.
        checkpoint_asserts = False

        def checkpoint():
            # -- no-op unless checkpoint_asserts is True
            if checkpoint_asserts:
                self.assert_integrity_idxs_take()
                if node in self.idxs_memo:
                    toposort(self.idxs_memo[node])
                if node in self.take_memo:
                    for take in self.take_memo[node]:
                        toposort(take)

        checkpoint()

        # wanted_idxs are fixed, whereas idxs_memo
        # is full of unions, that can grow in subsequent recursive
        # calls to build_idxs_vals with node as argument.
        assert wanted_idxs != self.idxs_memo.get(node)

        # -- easy exit case
        if node.name == 'hyperopt_param':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- easy exit case
        elif node.name == 'hyperopt_result':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- literal case: always take from universal set
        elif node.name == 'literal':
            if node in self.idxs_memo:
                # -- reuse the (all_idxs, all_vals) of the first recorded take
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- initialize idxs_memo to full set
                all_idxs = self.expr_idxs
                n_times = scope.len(all_idxs)
                # -- put array_union into graph for consistency, though it is
                # not necessary
                all_idxs = scope.array_union(all_idxs)
                self.idxs_memo[node] = all_idxs
                # -- the literal repeated len(self.expr_idxs) times
                all_vals = scope.asarray(scope.repeat(n_times, node))
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                assert node not in self.take_memo
                self.take_memo[node] = [wanted_vals]
                checkpoint()
            return wanted_vals

        # -- switch case: complicated
        elif node.name == 'switch':
            if (node in self.idxs_memo
                    and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- these indices are already part of the union; take them
                #    from the existing (all_idxs, all_vals) pair
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- we need to add some indexes
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]
                    assert all_idxs.name == 'array_union'
                    all_idxs.pos_args.append(wanted_idxs)
                else:
                    all_idxs = scope.array_union(wanted_idxs)

                # -- pos_args[0] selects the branch; vectorize it over all_idxs
                choice = node.pos_args[0]
                all_choices = self.build_idxs_vals(choice, all_idxs)

                # -- split all_idxs into one index set per option, vectorize
                #    each option over its own subset, then merge the results
                options = node.pos_args[1:]
                args_idxs = scope.vchoice_split(all_idxs, all_choices,
                                                len(options))
                all_vals = scope.vchoice_merge(all_idxs, all_choices)
                for opt_ii, idxs_ii in zip(options, args_idxs):
                    all_vals.pos_args.append(
                        as_apply([
                            idxs_ii,
                            self.build_idxs_vals(opt_ii, idxs_ii),
                        ]))

                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    # NOTE(review): when `node` was already memoized,
                    # wanted_idxs was appended to this same union at the top
                    # of this branch, so it now appears twice. Presumably
                    # harmless for a union, but looks redundant -- confirm
                    # before removing either append.
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    # -- re-point earlier takes at the rebuilt values
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        # -- general case
        else:
            # -- this is a general node.
            #    It is generally handled with idxs_memo,
            #    but vectorize_stochastic may immediately transform it into
            #    a more compact form.
            if (node in self.idxs_memo
                    and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- return the take already built for exactly these indices
                for take in self.take_memo[node]:
                    if take.pos_args[2] == wanted_idxs:
                        return take
                raise NotImplementedError('how did this happen?')
                #all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                #wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                #self.take_memo[node].append(wanted_vals)
                #checkpoint()
            else:
                # XXX
                # -- determine if wanted_idxs is actually a subset of the idxs
                # that we are already computing.  This is not only an
                # optimization, but prevents the creation of cycles, which
                # would otherwise occur if we have a graph of the form
                # switch(f(a), g(a), 0). If there are other switches inside f
                # and g, does this get trickier?

                # -- assume we need to add some indexes
                checkpoint()
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]

                else:
                    all_idxs = scope.array_union(wanted_idxs)
                checkpoint()

                # -- apply node's op elementwise over all_idxs: each positional
                #    and keyword argument is itself vectorized over all_idxs
                all_vals = scope.idxs_map(all_idxs, node.name)
                for ii, aa in enumerate(node.pos_args):
                    all_vals.pos_args.append(
                        as_apply(
                            [all_idxs,
                             self.build_idxs_vals(aa, all_idxs)]))
                    checkpoint()
                for ii, (nn, aa) in enumerate(node.named_args):
                    all_vals.named_args.append([
                        nn,
                        as_apply(
                            [all_idxs,
                             self.build_idxs_vals(aa, all_idxs)])
                    ])
                    checkpoint()
                # -- may rewrite the idxs_map into a more compact form
                all_vals = vectorize_stochastic(all_vals)

                checkpoint()
                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    toposort(self.idxs_memo[node])
                    # -- this catches the cycle bug mentioned above
                    # -- re-point earlier takes at the rebuilt values
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        return wanted_vals