示例#1
0
 def build_vals(self):
     """
     Fill ``self.vals_memo`` with one vectorized node per entry of
     ``self.dfs_nodes``.

     Each node is handled by the first rule that matches:

     * a node named "literal" is repeated ``scope.len(idxs)`` times and
       wrapped with ``scope.asarray``;
     * a node present in ``self.choice_memo`` becomes a
       ``scope.vchoice_merge`` whose extra positional args are
       (idxs, vals) pairs, one per positional arg of the node;
     * any other node becomes a ``scope.idxs_map`` over its own idxs,
       with every input swapped for its (idxs, vals) pair.

     ``self.vals_memo`` is read for a node's arguments before the node
     itself is stored, so ``self.dfs_nodes`` presumably yields arguments
     ahead of the nodes that use them -- TODO confirm against the code
     that builds ``dfs_nodes``.
     """
     for current in self.dfs_nodes:
         if current.name == "literal":
             # -- one copy of the literal per index
             repeat_count = scope.len(self.idxs_memo[current])
             repeated = scope.repeat(repeat_count, current)
             vectorized = scope.asarray(repeated)
         elif current in self.choice_memo:
             # -- the choice node is already vectorized: it is its own value
             selector = self.choice_memo[current]
             self.vals_memo[selector] = selector
             # -- merge the per-option sub-graphs back into one node
             vectorized = scope.vchoice_merge(self.idxs_memo[current], selector)
             option_pairs = []
             for option in current.pos_args:
                 pair = as_apply([self.idxs_memo[option], self.vals_memo[option]])
                 option_pairs.append(pair)
             vectorized.pos_args.extend(option_pairs)
         else:
             # -- generic node: map the original function over its idxs
             vectorized = scope.idxs_map(self.idxs_memo[current], current.name)
             vectorized.pos_args.extend(current.pos_args)
             vectorized.named_args.extend(current.named_args)
             for source in current.inputs():
                 pair = as_apply([self.idxs_memo[source], self.vals_memo[source]])
                 vectorized.replace_input(source, pair)
         self.vals_memo[current] = vectorized
示例#2
0
    def build_idxs_vals(self, node, wanted_idxs):
        """
        This recursive procedure should be called on an output-node.

        Build (or extend) the vectorized graph for `node` and return a node
        that yields the values of `node` at the positions `wanted_idxs`.

        Dispatch is on `node.name`:

        * 'hyperopt_param' / 'hyperopt_result': not vectorized; the call is
          forwarded to `node.arg['obj']`.
        * 'literal': the node is repeated `scope.len(self.expr_idxs)` times
          and `wanted_idxs` is taken from that.
        * 'switch': the choice (first positional arg) is built over the union
          of requested indexes, split per option with `scope.vchoice_split`,
          and the per-option results are merged with `scope.vchoice_merge`.
        * anything else: a `scope.idxs_map` over the union of requested
          indexes, passed through `vectorize_stochastic`.

        Side effects: `self.idxs_memo[node]` is set to the `array_union` of
        index sets for `node`, and `self.take_memo[node]` to the list of
        `idxs_take` nodes created for it.  On repeat visits the existing
        union is extended in place and, in the switch and general cases,
        earlier takes are re-pointed at the newly built values node.

        node: graph node; this method reads its `name`, `arg`, `pos_args`
            and `named_args`.
        wanted_idxs: the index-set expression whose values are requested.
            It must not be the union object memoized for `node` (asserted).

        Returns the result of `scope.idxs_take(all_idxs, all_vals,
        wanted_idxs)`, or, on the general-case fast path, the take that was
        previously created for the same `wanted_idxs`.

        Raises NotImplementedError when the general-case fast path finds
        `wanted_idxs` in the memoized union but no matching take.
        """
        # -- debugging switch: when True, checkpoint() runs the (expensive)
        #    consistency checks below at every step.
        checkpoint_asserts = False

        def checkpoint():
            # -- no-op unless checkpoint_asserts is enabled above
            if checkpoint_asserts:
                self.assert_integrity_idxs_take()
                if node in self.idxs_memo:
                    toposort(self.idxs_memo[node])
                if node in self.take_memo:
                    for take in self.take_memo[node]:
                        toposort(take)

        checkpoint()

        # wanted_idxs are fixed, whereas idxs_memo
        # is full of unions, that can grow in subsequent recursive
        # calls to build_idxs_vals with node as argument.
        assert wanted_idxs != self.idxs_memo.get(node)

        # -- easy exit case
        if node.name == 'hyperopt_param':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- easy exit case
        elif node.name == 'hyperopt_result':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- literal case: always take from universal set
        elif node.name == 'literal':
            if node in self.idxs_memo:
                # -- reuse the (idxs, vals) pair stored as the first two
                #    positional args of the first take made for this literal
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- initialize idxs_memo to full set
                all_idxs = self.expr_idxs
                n_times = scope.len(all_idxs)
                # -- put array_union into graph for consistency, though it is
                # not necessary
                all_idxs = scope.array_union(all_idxs)
                self.idxs_memo[node] = all_idxs
                all_vals = scope.asarray(scope.repeat(n_times, node))
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                assert node not in self.take_memo
                self.take_memo[node] = [wanted_vals]
                checkpoint()
            return wanted_vals

        # -- switch case: complicated
        elif node.name == 'switch':
            if (node in self.idxs_memo
                and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- wanted_idxs is already part of the union: just add
                #    another take on the existing (idxs, vals) pair
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- we need to add some indexes
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]
                    assert all_idxs.name == 'array_union'
                    # -- the union is grown in place *before* recursing
                    # NOTE(review): wanted_idxs is appended to this same
                    # union again further down (after the takes are built),
                    # so it appears to end up in pos_args twice -- confirm
                    # whether that is intended.
                    all_idxs.pos_args.append(wanted_idxs)
                else:
                    all_idxs = scope.array_union(wanted_idxs)

                # -- the first positional arg selects the option; the
                #    remaining ones are the options themselves
                choice = node.pos_args[0]
                all_choices = self.build_idxs_vals(choice, all_idxs)

                options = node.pos_args[1:]
                args_idxs = scope.vchoice_split(all_idxs, all_choices,
                                                len(options))
                all_vals = scope.vchoice_merge(all_idxs, all_choices)
                # -- each option is built only over the idxs routed to it
                for opt_ii, idxs_ii in zip(options, args_idxs):
                    all_vals.pos_args.append(
                        as_apply([
                            idxs_ii,
                            self.build_idxs_vals(opt_ii, idxs_ii),
                        ]))

                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    # -- point every earlier take at the freshly built
                    #    values node
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        # -- general case
        else:
            # -- this is a general node.
            #    It is generally handled with idxs_memo,
            #    but vectorize_stochastic may immediately transform it into
            #    a more compact form.
            if (node in self.idxs_memo
                and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- return the take already created for these idxs (its
                #    third positional arg is the wanted_idxs it was built for)
                for take in self.take_memo[node]:
                    if take.pos_args[2] == wanted_idxs:
                        return take
                # -- wanted_idxs is in the union but no take refers to it
                raise NotImplementedError('how did this happen?')
                # -- disabled alternative (commented out, and unreachable
                #    after the raise above):
                #all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                #wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                #self.take_memo[node].append(wanted_vals)
                #checkpoint()
            else:
                # XXX
                # -- determine if wanted_idxs is actually a subset of the idxs
                # that we are already computing.  This is not only an
                # optimization, but prevents the creation of cycles, which
                # would otherwise occur if we have a graph of the form
                # switch(f(a), g(a), 0). If there are other switches inside f
                # and g, does this get trickier?

                # -- assume we need to add some indexes
                checkpoint()
                if node in self.idxs_memo:
                    # -- unlike the switch case, the existing union is not
                    #    extended until after the arguments are built
                    all_idxs = self.idxs_memo[node]

                else:
                    all_idxs = scope.array_union(wanted_idxs)
                checkpoint()

                # -- map the original function (by name) over all_idxs, with
                #    each argument rebuilt over the same idxs
                all_vals = scope.idxs_map(all_idxs, node.name)
                for ii, aa in enumerate(node.pos_args):
                    all_vals.pos_args.append(as_apply([
                        all_idxs, self.build_idxs_vals(aa, all_idxs)]))
                    checkpoint()
                for ii, (nn, aa) in enumerate(node.named_args):
                    all_vals.named_args.append([nn, as_apply([
                        all_idxs, self.build_idxs_vals(aa, all_idxs)])])
                    checkpoint()
                all_vals = vectorize_stochastic(all_vals)

                checkpoint()
                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    toposort(self.idxs_memo[node])
                    # -- this catches the cycle bug mentioned above
                    # -- point every earlier take at the freshly built
                    #    values node
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        return wanted_vals
示例#3
0
    def build_idxs_vals(self, node, wanted_idxs):
        """
        This recursive procedure should be called on an output-node.

        Build (or extend) the vectorized graph for `node` and return a node
        that yields the values of `node` at the positions `wanted_idxs`.

        Dispatch is on `node.name`:

        * 'hyperopt_param' / 'hyperopt_result': not vectorized; the call is
          forwarded to `node.arg['obj']`.
        * 'literal': the node is repeated `scope.len(self.expr_idxs)` times
          and `wanted_idxs` is taken from that.
        * 'switch': the choice (first positional arg) is built over the union
          of requested indexes, split per option with `scope.vchoice_split`,
          and the per-option results are merged with `scope.vchoice_merge`.
        * anything else: a `scope.idxs_map` over the union of requested
          indexes, passed through `vectorize_stochastic`.

        Side effects: `self.idxs_memo[node]` is set to the `array_union` of
        index sets for `node`, and `self.take_memo[node]` to the list of
        `idxs_take` nodes created for it.  On repeat visits the existing
        union is extended in place and, in the switch and general cases,
        earlier takes are re-pointed at the newly built values node.

        node: graph node; this method reads its `name`, `arg`, `pos_args`
            and `named_args`.
        wanted_idxs: the index-set expression whose values are requested.
            It must not be the union object memoized for `node` (asserted).

        Returns the result of `scope.idxs_take(all_idxs, all_vals,
        wanted_idxs)`, or, on the general-case fast path, the take that was
        previously created for the same `wanted_idxs`.

        Raises NotImplementedError when the general-case fast path finds
        `wanted_idxs` in the memoized union but no matching take.
        """
        # -- debugging switch: when True, checkpoint() runs the (expensive)
        #    consistency checks below at every step.
        checkpoint_asserts = False

        def checkpoint():
            # -- no-op unless checkpoint_asserts is enabled above
            if checkpoint_asserts:
                self.assert_integrity_idxs_take()
                if node in self.idxs_memo:
                    toposort(self.idxs_memo[node])
                if node in self.take_memo:
                    for take in self.take_memo[node]:
                        toposort(take)

        checkpoint()

        # wanted_idxs are fixed, whereas idxs_memo
        # is full of unions, that can grow in subsequent recursive
        # calls to build_idxs_vals with node as argument.
        assert wanted_idxs != self.idxs_memo.get(node)

        # -- easy exit case
        if node.name == 'hyperopt_param':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- easy exit case
        elif node.name == 'hyperopt_result':
            # -- ignore, not vectorizing
            return self.build_idxs_vals(node.arg['obj'], wanted_idxs)

        # -- literal case: always take from universal set
        elif node.name == 'literal':
            if node in self.idxs_memo:
                # -- reuse the (idxs, vals) pair stored as the first two
                #    positional args of the first take made for this literal
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- initialize idxs_memo to full set
                all_idxs = self.expr_idxs
                n_times = scope.len(all_idxs)
                # -- put array_union into graph for consistency, though it is
                # not necessary
                all_idxs = scope.array_union(all_idxs)
                self.idxs_memo[node] = all_idxs
                all_vals = scope.asarray(scope.repeat(n_times, node))
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                assert node not in self.take_memo
                self.take_memo[node] = [wanted_vals]
                checkpoint()
            return wanted_vals

        # -- switch case: complicated
        elif node.name == 'switch':
            if (node in self.idxs_memo
                    and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- wanted_idxs is already part of the union: just add
                #    another take on the existing (idxs, vals) pair
                all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                self.take_memo[node].append(wanted_vals)
                checkpoint()
            else:
                # -- we need to add some indexes
                if node in self.idxs_memo:
                    all_idxs = self.idxs_memo[node]
                    assert all_idxs.name == 'array_union'
                    # -- the union is grown in place *before* recursing
                    # NOTE(review): wanted_idxs is appended to this same
                    # union again further down (after the takes are built),
                    # so it appears to end up in pos_args twice -- confirm
                    # whether that is intended.
                    all_idxs.pos_args.append(wanted_idxs)
                else:
                    all_idxs = scope.array_union(wanted_idxs)

                # -- the first positional arg selects the option; the
                #    remaining ones are the options themselves
                choice = node.pos_args[0]
                all_choices = self.build_idxs_vals(choice, all_idxs)

                options = node.pos_args[1:]
                args_idxs = scope.vchoice_split(all_idxs, all_choices,
                                                len(options))
                all_vals = scope.vchoice_merge(all_idxs, all_choices)
                # -- each option is built only over the idxs routed to it
                for opt_ii, idxs_ii in zip(options, args_idxs):
                    all_vals.pos_args.append(
                        as_apply([
                            idxs_ii,
                            self.build_idxs_vals(opt_ii, idxs_ii),
                        ]))

                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    # -- point every earlier take at the freshly built
                    #    values node
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        # -- general case
        else:
            # -- this is a general node.
            #    It is generally handled with idxs_memo,
            #    but vectorize_stochastic may immediately transform it into
            #    a more compact form.
            if (node in self.idxs_memo
                    and wanted_idxs in self.idxs_memo[node].pos_args):
                # -- phew, easy case
                # -- return the take already created for these idxs (its
                #    third positional arg is the wanted_idxs it was built for)
                for take in self.take_memo[node]:
                    if take.pos_args[2] == wanted_idxs:
                        return take
                # -- wanted_idxs is in the union but no take refers to it
                raise NotImplementedError('how did this happen?')
                # -- disabled alternative (commented out, and unreachable
                #    after the raise above):
                #all_idxs, all_vals = self.take_memo[node][0].pos_args[:2]
                #wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs)
                #self.take_memo[node].append(wanted_vals)
                #checkpoint()
            else:
                # XXX
                # -- determine if wanted_idxs is actually a subset of the idxs
                # that we are already computing.  This is not only an
                # optimization, but prevents the creation of cycles, which
                # would otherwise occur if we have a graph of the form
                # switch(f(a), g(a), 0). If there are other switches inside f
                # and g, does this get trickier?

                # -- assume we need to add some indexes
                checkpoint()
                if node in self.idxs_memo:
                    # -- unlike the switch case, the existing union is not
                    #    extended until after the arguments are built
                    all_idxs = self.idxs_memo[node]

                else:
                    all_idxs = scope.array_union(wanted_idxs)
                checkpoint()

                # -- map the original function (by name) over all_idxs, with
                #    each argument rebuilt over the same idxs
                all_vals = scope.idxs_map(all_idxs, node.name)
                for ii, aa in enumerate(node.pos_args):
                    all_vals.pos_args.append(
                        as_apply(
                            [all_idxs,
                             self.build_idxs_vals(aa, all_idxs)]))
                    checkpoint()
                for ii, (nn, aa) in enumerate(node.named_args):
                    all_vals.named_args.append([
                        nn,
                        as_apply(
                            [all_idxs,
                             self.build_idxs_vals(aa, all_idxs)])
                    ])
                    checkpoint()
                all_vals = vectorize_stochastic(all_vals)

                checkpoint()
                wanted_vals = scope.idxs_take(
                    all_idxs,  # -- may grow in future
                    all_vals,  # -- may be replaced in future
                    wanted_idxs)  # -- fixed.
                if node in self.idxs_memo:
                    assert self.idxs_memo[node].name == 'array_union'
                    self.idxs_memo[node].pos_args.append(wanted_idxs)
                    toposort(self.idxs_memo[node])
                    # -- this catches the cycle bug mentioned above
                    # -- point every earlier take at the freshly built
                    #    values node
                    for take in self.take_memo[node]:
                        assert take.name == 'idxs_take'
                        take.pos_args[1] = all_vals
                    self.take_memo[node].append(wanted_vals)
                else:
                    self.idxs_memo[node] = all_idxs
                    self.take_memo[node] = [wanted_vals]
                checkpoint()

        return wanted_vals