# Example no. 1
# 0
def test_clone():
    """Check that ``clone`` makes an independent, equivalent copy of a graph.

    The clone must share no node with the original graph. When both graphs
    are given an rng seeded identically, they must evaluate to equal results.
    """
    config = config0()
    config2 = clone(config)

    # -- no node of the clone may be a node of the original graph
    nodeset = set(dfs(config))
    assert not any(n in nodeset for n in dfs(config2))

    seeded_config = recursive_set_rng_kwarg(config, scope.rng_from_seed(5))
    r = rec_eval(seeded_config)
    # print(x) with a single argument behaves the same on Python 2 and 3
    print(r)
    r2 = rec_eval(recursive_set_rng_kwarg(config2, scope.rng_from_seed(5)))

    print(r2)
    assert r == r2
# Example no. 2
# 0
def test_clone():
    """Check that ``clone`` makes an independent, equivalent copy of a graph.

    The clone must share no node with the original graph. When both graphs
    are given an rng seeded identically, they must evaluate to equal results.
    """
    config = config0()
    config2 = clone(config)

    # -- no node of the clone may be a node of the original graph
    nodeset = set(dfs(config))
    assert not any(n in nodeset for n in dfs(config2))

    seeded_config = recursive_set_rng_kwarg(
                config,
                scope.rng_from_seed(5))
    r = rec_eval(seeded_config)
    # print(x) with a single argument behaves the same on Python 2 and 3
    print(r)
    r2 = rec_eval(
            recursive_set_rng_kwarg(
                config2,
                scope.rng_from_seed(5)))

    print(r2)
    assert r == r2
# Example no. 3
# 0
    def __init__(self, bandit, seed=seed, cmd=None, workdir=None):
        """Build a vectorized, seeded sampling graph from ``bandit.template``.

        Parameters
        ----------
        bandit : object
            Must expose a ``template`` attribute (a pyll graph). A clone of
            that template, not the template itself, is what gets vectorized.
        seed :
            Seed for the ``np.random.RandomState`` stored as ``self.rng``.
            That rng is attached to the sampling graph.
        cmd, workdir :
            Stored unchanged on the instance; not used by this constructor.

        Attributes set here include:

        * ``vh`` -- the ``VectorizeHelper`` built over the cloned template.
        * ``vtemplate`` -- the cloned template after repeat(dist(...)) nodes
          have been replaced by their vectorized versions.
        * ``idxs_by_nid`` / ``vals_by_nid`` / ``name_by_nid`` -- dicts keyed
          by node-id strings (e.g. ``'node_5'``). They keep only nodes whose
          name is in ``pyll.stochastic.implicit_stochastic_symbols``, minus
          ``one_of`` nodes.
        * ``s_specs_idxs_vals`` -- a ``pos_args`` node over
          ``(vtemplate, idxs_by_nid, vals_by_nid)`` with ``self.rng`` attached.
        * ``doc_coords`` -- node-id string -> pretty name of the corresponding
          node in the original ``bandit.template``.
        """
        self.bandit = bandit
        self.seed = seed
        self.rng = np.random.RandomState(self.seed)
        self.cmd = cmd
        self.workdir = workdir
        self.new_ids = ['dummy_id']
        # -- N.B. not necessarily actually a range
        # NOTE(review): the Literal wraps this same list object, so later
        # mutation of self.new_ids presumably feeds into the graph -- confirm.
        self.s_new_ids = pyll.Literal(self.new_ids)
        self.template_clone_memo = {}
        template = pyll.clone(self.bandit.template, self.template_clone_memo)
        vh = self.vh = VectorizeHelper(template, self.s_new_ids)
        vh.build_idxs()
        vh.build_vals()
        # the keys (nid) here are strings like 'node_5'
        idxs_by_nid = vh.idxs_by_id()
        vals_by_nid = vh.vals_by_id()
        name_by_nid = vh.name_by_id()
        assert set(idxs_by_nid.keys()) == set(vals_by_nid.keys())
        assert set(name_by_nid.keys()) == set(vals_by_nid.keys())

        # -- replace repeat(dist(...)) with vectorized versions
        t_i_v = replace_repeat_stochastic(
                pyll.as_apply([
                    vh.vals_memo[template], idxs_by_nid, vals_by_nid]))
        assert t_i_v.name == 'pos_args'
        template, s_idxs_by_nid, s_vals_by_nid = t_i_v.pos_args
        # -- fetch the dictionaries off the top of the cloned graph
        idxs_by_nid = dict(s_idxs_by_nid.named_args)
        vals_by_nid = dict(s_vals_by_nid.named_args)

        # -- remove non-stochastic nodes from the idxs and vals
        #    because
        #    (a) they should be irrelevant for BanditAlgo operation,
        #    (b) they can be reconstructed from the template and the
        #    stochastic choices, and
        #    (c) they are often annoying when printing / saving.
        #    one_of nodes are removed too, because they are duplicates of
        #    the randint node that makes their choice.
        #    Iterate over a snapshot: entries are deleted inside the loop,
        #    and deleting from a dict while iterating its live items view
        #    raises RuntimeError on Python 3.
        stochastic_symbols = pyll.stochastic.implicit_stochastic_symbols
        for node_id, name in list(name_by_nid.items()):
            if name not in stochastic_symbols or name == 'one_of':
                del name_by_nid[node_id]
                del vals_by_nid[node_id]
                del idxs_by_nid[node_id]

        # -- make the graph runnable and SON-encodable
        # N.B. operates inplace
        self.s_specs_idxs_vals = recursive_set_rng_kwarg(
                scope.pos_args(template, idxs_by_nid, vals_by_nid),
                pyll.as_apply(self.rng))

        self.vtemplate = template
        self.idxs_by_nid = idxs_by_nid
        self.vals_by_nid = vals_by_nid
        self.name_by_nid = name_by_nid

        # -- compute some document coordinate strings for the node_ids
        pnames = pretty_names(bandit.template, prefix=None)
        doc_coords = self.doc_coords = {}
        for node, pname in pnames.items():
            cnode = self.template_clone_memo[node]
            if cnode.name == 'one_of':
                # -- one_of was pruned above; label its randint chooser instead
                choice_node = vh.choice_memo[cnode]
                assert choice_node.name == 'randint'
                doc_coords[vh.node_id[choice_node]] = pname
            # -- only nodes that survived the pruning above get coordinates
            if cnode in vh.node_id and vh.node_id[cnode] in name_by_nid:
                doc_coords[vh.node_id[cnode]] = pname