Example #1
import torch as tr  # other helpers come from the surrounding repository

def run_trial(num_bases, num_blocks, max_levels):

    env = BlocksWorldEnv(show=False)

    # rejection sample non-trivial instance
    thing_below, goal_thing_below = random_problem_instance(
        env, num_blocks, max_levels, num_bases)

    # build the abstract register machine and its neural virtualization
    am = make_abstract_machine(env, num_bases, max_levels)
    nvm = virtualize(am)

    am_results = run_machine(am, goal_thing_below, {"jnt": "rest"})

    # restore the original block configuration before the nvm run
    env.reset()
    env.load_blocks(thing_below, num_bases)

    nvm_results = run_machine(nvm, goal_thing_below,
                              {"jnt": tr.tensor(am.ik["rest"]).float()})

    env.close()

    return am_results, nvm_results, nvm.size(), thing_below, goal_thing_below
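
A minimal driver for this function might look like the sketch below; the problem sizes are illustrative only, and (per Example #5 below) each results tuple holds ticks, runtime, and the symbolic and spatial rewards.

# hypothetical driver for run_trial; sizes are illustrative only
if __name__ == "__main__":
    am_results, nvm_results, size, thing_below, goal_thing_below = run_trial(
        num_bases=5, num_blocks=7, max_levels=3)
    print("am results (ticks, runtime, sym, spa):", am_results)
    print("nvm results:", nvm_results, "| nvm size:", size)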
Example #2
        # the snippet begins mid-class; this __init__ is a minimal
        # reconstruction implied by the step_hook body below
        class Tracker:
            def __init__(self, goal_thing_below):
                self.goal_thing_below = goal_thing_below
                self.mp, self.sym = [], []

            def step_hook(self, env, action):
                self.mp.append(env.movement_penalty())
                self.sym.append(
                    compute_symbolic_reward(env, self.goal_thing_below))

        # load
        tracker = Tracker(goal_thing_below)
        env = BlocksWorldEnv(show=False, step_hook=tracker.step_hook)
        env.load_blocks(thing_below)
        # run rvm
        rvm = make_abstract_machine(env,
                                    num_bases,
                                    max_levels,
                                    gen_regs=["r0", "r1"])
        nvm = virtualize(rvm, nv.default_activator)
        # run
        # invert the goal map; rename "none" to the machine's "nil" symbol
        goal_thing_above = env.invert(goal_thing_below)
        for key, val in goal_thing_above.items():
            if val == "none": goal_thing_above[key] = "nil"
        memorize_env(rvm, goal_thing_above)
        rvm.reset({"jnt": "rest"})
        rvm.mount("main")
        # run the mounted program to completion
        while True:
            done = rvm.tick()
            if done: break
        # run_machine(rvm, goal_thing_below, reset_dict={"jnt": "rest"})
        num_time_steps = rvm.tick_counter
        env.close()
        # save
        with open("fcase_batched_data.pkl", "wb") as f:
            ...  # dump call truncated in the original snippet
Example #3
                start_rep = time.perf_counter()
                results.append([])

                if prob_freq != "once":
                    problem = domain.random_problem_instance()
                env = BlocksWorldEnv(show=False,
                                     step_hook=penalty_tracker.step_hook)
                env.load_blocks(problem.thing_below)

                # set up rvm and virtualize
                rvm = make_abstract_machine(env, domain)
                rvm.reset({"jnt": "rest"})
                rvm.mount("main")

                nvm = virtualize(rvm,
                                 σ=nv.default_activator,
                                 detach_gates=detach_gates)
                nvm.mount("main")
                # batched initial weights and register activities
                W_init = {
                    name: {
                        0: nvm.net.batchify_weights(conn.W)
                    }
                    for name, conn in nvm.connections.items()
                }
                v_init = {
                    name: {
                        0: nvm.net.batchify_activities(reg.content)
                    }
                    for name, reg in nvm.registers.items()
                }
                v_init["jnt"][0] = nvm.net.batchify_activities(
Example #4
                env = BlocksWorldEnv(show=showenv,
                                     step_hook=penalty_tracker.step_hook)
                env.load_blocks({
                    "b%d" % n: "t%d" % n
                    for n in range(num_bases)
                })  # placeholder for rvm construction

                # set up rvm and virtualize
                rvm = make_abstract_machine(env,
                                            num_bases,
                                            max_levels,
                                            gen_regs=["r0", "r1"])
                rvm.reset({"jnt": "rest"})
                rvm.mount("main")

                nvm = virtualize(rvm, σ)
                init_regs, init_conns = nvm.get_state()
                init_regs["jnt"] = tr.tensor(rvm.ik["rest"]).float()

                # set up trainable connections
                conn_params = {name: init_conns[name] for name in trainable}
                for p in conn_params.values():
                    p.requires_grad_()
                opt = tr.optim.Adam(conn_params.values(), lr=learning_rate)

                # save original values for comparison
                orig_conns = {
                    name: init_conns[name].detach().clone()
                    for name in trainable
                }
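
Example #4 is cut off right after orig_conns is saved; one plausible use of those saved tensors, sketched here purely as an assumption, is to report how far training later moves each trainable connection:

# hypothetical follow-up: measure how far each trained connection
# drifted from the values saved in orig_conns above
for name in trainable:
    drift = (conn_params[name] - orig_conns[name]).abs().max().item()
    print("%s: max weight change %.6f" % (name, drift))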
Example #5
import time
import torch as tr  # other helpers come from the surrounding repository

def run_trial(domain):

    env = BlocksWorldEnv(show=False)

    # rejection sample non-trivial instance
    problem = domain.random_problem_instance()
    env.reset()
    env.load_blocks(problem.thing_below, num_bases=domain.num_bases)

    # set up rvm and virtualize
    rvm = make_abstract_machine(env, domain)
    memorize_problem(rvm, problem)
    rvm.reset({"jnt": "rest"})
    rvm.mount("main")
    nvm = virtualize(rvm, σ=nv.default_activator, detach_gates=True)
    nvm.mount("main")
    # batched initial weights and register activities for the nvm network
    W_init = {
        name: {
            0: nvm.net.batchify_weights(conn.W)
        }
        for name, conn in nvm.connections.items()
    }
    v_init = {
        name: {
            0: nvm.net.batchify_activities(reg.content)
        }
        for name, reg in nvm.registers.items()
    }
    v_init["jnt"][0] = nvm.net.batchify_activities(
        tr.tensor(rvm.ik["rest"]).float())

    # rvm_results = run_machine(rvm, problem.goal_thing_below, {"jnt": "rest"})
    start = time.perf_counter()
    tar_changed = False
    while True:
        done = rvm.tick()
        # if the target register changed on the previous tick, drive the
        # simulated arm to the newly decoded joint position
        if tar_changed:
            position = rvm.ik[rvm.registers["jnt"].content]
            env.goto_position(position, speed=1.5)
        if done: break
        tar_changed = (rvm.registers["tar"].content !=
                       rvm.registers["tar"].old_content)
    rvm_ticks = rvm.tick_counter
    rvm_runtime = time.perf_counter() - start
    rvm_sym = compute_symbolic_reward(env, problem.goal_thing_below)
    rvm_spa = compute_spatial_reward(env, problem.goal_thing_below)
    rvm_results = rvm_ticks, rvm_runtime, rvm_sym, rvm_spa

    # nvm_results = run_machine(nvm, problem.goal_thing_below, {"jnt": tr.tensor(rvm.ik["rest"]).float()})
    env.reset()
    env.load_blocks(problem.thing_below, num_bases=domain.num_bases)
    start = time.perf_counter()
    while True:
        t = nvm.net.tick_counter
        # halt once the decoded instruction pointer stops changing
        if t > 0 and nvm.decode("ipt", t, 0) == nvm.decode("ipt", t - 1, 0):
            break
        nvm.net.tick(W_init, v_init)
        nvm.pullback(t)
        # move the arm whenever the decoded target register changes
        if t > 1 and nvm.decode("tar", t - 2, 0) != nvm.decode(
                "tar", t - 1, 0):
            position = nvm.net.activities["jnt"][t][0, :, 0].detach().numpy()
            env.goto_position(position, speed=1.5)
    nvm_ticks = nvm.net.tick_counter
    nvm_runtime = time.perf_counter() - start
    nvm_sym = compute_symbolic_reward(env, problem.goal_thing_below)
    nvm_spa = compute_spatial_reward(env, problem.goal_thing_below)
    nvm_results = nvm_ticks, nvm_runtime, nvm_sym, nvm_spa

    env.close()
    return rvm_results, nvm_results, nvm.size(), problem
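
A hypothetical caller for this variant of run_trial; the domain constructor named below is an assumption, standing in for however the repository actually builds its domain object:

# hypothetical usage; BlocksWorldDomain is an assumed constructor name
domain = BlocksWorldDomain(num_bases=5, max_levels=3)
rvm_results, nvm_results, size, problem = run_trial(domain)
rvm_ticks, rvm_runtime, rvm_sym, rvm_spa = rvm_results
print("rvm: %d ticks, %.3fs, sym=%s, spa=%s" % (
    rvm_ticks, rvm_runtime, rvm_sym, rvm_spa))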
Example #6
    showresults = True
    # tr.autograd.set_detect_anomaly(True)

    if run_exp:

        results = []
        for rep in range(num_repetitions):
            start_rep = time.perf_counter()
            results.append([])

            env = BlocksWorldEnv(show=False)
            # placeholder blocks for nvm init
            env.load_blocks({"b%d" % n: "t%d" % n for n in range(num_bases)})

            rvm = make_abstract_machine(env, num_bases, max_levels)
            nvm = virtualize(rvm)
            # print(nvm.size())
            # input('.')
            init_regs, init_conns = nvm.get_state()
            orig_ik_W = init_conns["ik"].clone()
            init_regs["jnt"] = tr.tensor(rvm.ik["rest"]).float()

            # set up trainable connections
            conn_params = {
                name: init_conns[name]
                # for name in ["ik", "to", "tc", "pc", "pc"]
                for name in ["ik"]
            }
            for p in conn_params.values():
                p.requires_grad_()
            opt = tr.optim.Adam(conn_params.values(), lr=0.00001)
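
The fragment ends as soon as the optimizer is constructed; the mechanical shape of one Adam update on the trainable "ik" weights would be the following, where the loss term is only a placeholder for the experiment's actual objective:

# hypothetical continuation: one Adam step; the loss shown is a
# placeholder, not the experiment's actual objective
loss = (conn_params["ik"] - orig_ik_W).pow(2).sum()
opt.zero_grad()
loss.backward()
opt.step()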