Example #1
def main():
    # parse arguments
    parser: ArgumentParser = ArgumentParser()
    parser.add_argument('--model_dir', type=str, required=True, help="Directory of nnet model")
    parser.add_argument('--data_dir', type=str, required=True, help="Directory of data")

    parser.add_argument('--env', type=str, required=True, help="Environment: cube3, 15-puzzle, 24-puzzle")
    parser.add_argument('--max_steps', type=int, default=None, help="Maximum number of steps to take when solving "
                                                                    "with GBFS. If none is given, then this "
                                                                    "is set to the maximum number of "
                                                                    "backwards steps taken to create the "
                                                                    "data")

    args = parser.parse_args()

    # environment
    env: Environment = env_utils.get_environment(args.env)

    # get device and nnet
    on_gpu: bool
    device: torch.device
    device, devices, on_gpu = nnet_utils.get_device()
    print("device: %s, devices: %s, on_gpu: %s" % (device, devices, on_gpu))

    heuristic_fn = nnet_utils.load_heuristic_fn(args.model_dir, device, on_gpu, env.get_nnet_model(),
                                                env, clip_zero=False)

    gbfs_test(args.data_dir, env, heuristic_fn, max_solve_steps=args.max_steps)
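
Every example here calls nnet_utils.get_device() and unpacks a (device, devices, on_gpu) triple, but the helper itself is not shown in these snippets. Below is a minimal sketch of what such a helper could look like, assuming it simply prefers CUDA when available; this is an illustration, not the actual nnet_utils implementation.

# Hypothetical sketch of a get_device-style helper (assumption, not the real nnet_utils code)
from typing import List, Tuple

import torch


def get_device_sketch() -> Tuple[torch.device, List[int], bool]:
    on_gpu: bool = torch.cuda.is_available()
    if on_gpu:
        device = torch.device("cuda:0")
        devices = list(range(torch.cuda.device_count()))  # indices of all visible GPUs
    else:
        device = torch.device("cpu")
        devices = []

    return device, devices, on_gpu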
Example #2
def bwas_python(args, env: Environment, states: List[State]):
    # get device
    on_gpu: bool
    device: torch.device
    device, devices, on_gpu = nnet_utils.get_device()

    print("device: %s, devices: %s, on_gpu: %s" % (device, devices, on_gpu))

    heuristic_fn = nnet_utils.load_heuristic_fn(args.model_dir, device, on_gpu, env.get_nnet_model(),
                                                env, clip_zero=True, batch_size=args.nnet_batch_size)

    solns: List[List[int]] = []
    paths: List[List[State]] = []
    times: List = []
    num_nodes_gen: List[int] = []

    for state_idx, state in enumerate(states):
        start_time = time.time()

        num_itrs: int = 0
        astar = AStar([state], env, heuristic_fn, [args.weight])
        while not min(astar.has_found_goal()):
            astar.step(heuristic_fn, args.batch_size, verbose=args.verbose)
            num_itrs += 1

        path: List[State]
        soln: List[int]
        path_cost: float
        num_nodes_gen_idx: int
        goal_node: Node = astar.get_goal_node_smallest_path_cost(0)
        path, soln, path_cost = get_path(goal_node)

        num_nodes_gen_idx = astar.get_num_nodes_generated(0)

        solve_time = time.time() - start_time

        # record solution information
        solns.append(soln)
        paths.append(path)
        times.append(solve_time)
        num_nodes_gen.append(num_nodes_gen_idx)

        # check soln
        assert search_utils.is_valid_soln(state, soln, env)

        # print to screen
        timing_str = ", ".join(["%s: %.2f" % (key, val) for key, val in astar.timings.items()])
        print("Times - %s, num_itrs: %i" % (timing_str, num_itrs))

        print("State: %i, SolnCost: %.2f, # Moves: %i, "
              "# Nodes Gen: %s, Time: %.2f" % (state_idx, path_cost, len(soln),
                                               format(num_nodes_gen_idx, ","),
                                               solve_time))

    return solns, paths, times, num_nodes_gen
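
get_path(goal_node) above returns the state path, the move sequence, and the path cost, but its definition is not included in the snippet. A minimal sketch of a parent-pointer walk follows; the Node attribute names (state, parent, parent_move, path_cost) are assumptions for illustration, not the actual search API.

# Hypothetical sketch of a get_path-style helper; the Node attribute names are assumed
from typing import Any, List, Tuple


def get_path_sketch(goal_node) -> Tuple[List[Any], List[int], float]:
    path: List[Any] = []   # states from goal back to start (reversed below)
    soln: List[int] = []   # moves applied along the way

    node = goal_node
    while node is not None:
        path.append(node.state)
        if node.parent is not None:
            soln.append(node.parent_move)
        node = node.parent

    # reverse so both lists run start -> goal
    path.reverse()
    soln.reverse()

    return path, soln, goal_node.path_cost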
Example #3
def main():
    # arguments
    parser: ArgumentParser = ArgumentParser()
    args_dict: Dict[str, Any] = parse_arguments(parser)

    if not args_dict["debug"]:
        sys.stdout = data_utils.Logger(args_dict["output_save_loc"], "a")

    # environment
    env: Environment = env_utils.get_environment(args_dict['env'])

    # get device
    on_gpu: bool
    device: torch.device
    device, devices, on_gpu = nnet_utils.get_device()

    print("device: %s, devices: %s, on_gpu: %s" % (device, devices, on_gpu))

    # load nnet
    nnet: nn.Module
    itr: int
    update_num: int
    nnet, itr, update_num = load_nnet(args_dict['curr_dir'], env)

    nnet.to(device)
    if on_gpu and (not args_dict['single_gpu_training']):
        nnet = nn.DataParallel(nnet)

    # training
    while itr < args_dict['max_itrs']:
        # update
        targ_file: str = "%s/model_state_dict.pt" % args_dict['targ_dir']
        all_zeros: bool = not os.path.isfile(targ_file)
        heur_fn_i_q, heur_fn_o_qs, heur_procs = nnet_utils.start_heur_fn_runners(args_dict['num_update_procs'],
                                                                                 args_dict['targ_dir'],
                                                                                 device, on_gpu, env,
                                                                                 all_zeros=all_zeros,
                                                                                 clip_zero=True,
                                                                                 batch_size=args_dict[
                                                                                     "update_nnet_batch_size"])

        states_nnet: List[np.ndarray]
        outputs: np.ndarray
        states_nnet, outputs = do_update(args_dict["back_max"], update_num, env,
                                         args_dict['max_update_steps'], args_dict['update_method'],
                                         args_dict['states_per_update'], args_dict['eps_max'],
                                         heur_fn_i_q, heur_fn_o_qs)

        nnet_utils.stop_heuristic_fn_runners(heur_procs, heur_fn_i_q)

        # train nnet
        num_train_itrs: int = int(args_dict['epochs_per_update'] * np.ceil(outputs.shape[0] / args_dict['batch_size']))
        print("Training model for update number %i for %i iterations" % (update_num, num_train_itrs))
        last_loss = nnet_utils.train_nnet(nnet, states_nnet, outputs, device, args_dict['batch_size'], num_train_itrs,
                                          itr, args_dict['lr'], args_dict['lr_d'])
        itr += num_train_itrs

        # save nnet
        torch.save(nnet.state_dict(), "%s/model_state_dict.pt" % args_dict['curr_dir'])
        pickle.dump(itr, open("%s/train_itr.pkl" % args_dict['curr_dir'], "wb"), protocol=-1)
        pickle.dump(update_num, open("%s/update_num.pkl" % args_dict['curr_dir'], "wb"), protocol=-1)

        # test
        start_time = time.time()
        heuristic_fn = nnet_utils.get_heuristic_fn(nnet, device, env, batch_size=args_dict['update_nnet_batch_size'])
        max_solve_steps: int = min(update_num + 1, args_dict['back_max'])
        gbfs_test(args_dict['num_test'], args_dict['back_max'], env, heuristic_fn, max_solve_steps=max_solve_steps)

        print("Test time: %.2f" % (time.time() - start_time))

        # clear cuda memory
        torch.cuda.empty_cache()

        print("Last loss was %f" % last_loss)
        if last_loss < args_dict['loss_thresh']:
            # Update nnet
            print("Updating target network")
            copy_files(args_dict['curr_dir'], args_dict['targ_dir'])
            update_num = update_num + 1
            pickle.dump(update_num, open("%s/update_num.pkl" % args_dict['curr_dir'], "wb"), protocol=-1)

    print("Done")
Example #4
def bwas_cpp(args, env: Environment, states: List[State], results_file: str):
    assert (args.env.upper() in [
        'CUBE3', 'CUBE4', 'PUZZLE15', 'PUZZLE24', 'PUZZLE35', 'PUZZLE48',
        'LIGHTSOUT7'
    ])

    # Make c++ socket
    socket_name: str = "%s_cpp_socket" % results_file.split(".")[0]

    try:
        os.unlink(socket_name)
    except OSError:
        if os.path.exists(socket_name):
            raise

    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    sock.bind(socket_name)

    # Get state dimension
    if args.env.upper() == 'CUBE3':
        state_dim: int = 54
    elif args.env.upper() == 'PUZZLE15':
        state_dim: int = 16
    elif args.env.upper() == 'PUZZLE24':
        state_dim: int = 25
    elif args.env.upper() == 'PUZZLE35':
        state_dim: int = 36
    elif args.env.upper() == 'PUZZLE48':
        state_dim: int = 49
    elif args.env.upper() == 'LIGHTSOUT7':
        state_dim: int = 49
    else:
        raise ValueError("Unknown c++ environment: %s" % args.env)

    # start heuristic proc
    num_parallel: int = len(os.environ['CUDA_VISIBLE_DEVICES'].split(","))
    device, devices, on_gpu = nnet_utils.get_device()
    heur_fn_i_q, heur_fn_o_qs, heur_procs = nnet_utils.start_heur_fn_runners(
        num_parallel,
        args.model_dir,
        device,
        on_gpu,
        env,
        all_zeros=False,
        clip_zero=True,
        batch_size=args.nnet_batch_size)
    nnet_utils.heuristic_fn_par(states, env, heur_fn_i_q,
                                heur_fn_o_qs)  # initialize

    heur_proc = Process(target=cpp_listener,
                        args=(sock, args, env, state_dim, heur_fn_i_q,
                              heur_fn_o_qs))
    heur_proc.daemon = True
    heur_proc.start()

    time.sleep(2)  # give socket time to initialize

    solns: List[List[int]] = []
    paths: List[List[State]] = []
    times: List = []
    num_nodes_gen: List[int] = []

    for state_idx, state in enumerate(states):
        # Get string rep of state
        if args.env.upper() == "CUBE3":
            state_str: str = " ".join([str(x) for x in state.colors])
        elif args.env.upper() in [
                "PUZZLE15", "PUZZLE24", "PUZZLE35", "PUZZLE48"
        ]:
            state_str: str = " ".join([str(x) for x in state.tiles])
        elif args.env.upper() in ["LIGHTSOUT7"]:
            state_str: str = " ".join([str(x) for x in state.tiles])
        else:
            raise ValueError("Unknown c++ environment: %s" % args.env)

        popen = Popen([
            './cpp/parallel_weighted_astar', state_str,
            str(args.weight),
            str(args.batch_size), socket_name, args.env, "0"
        ],
                      stdout=PIPE,
                      stderr=PIPE,
                      bufsize=1,
                      universal_newlines=True)
        lines = []
        for stdout_line in iter(popen.stdout.readline, ""):
            stdout_line = stdout_line.strip('\n')
            lines.append(stdout_line)
            if args.verbose:
                sys.stdout.write("%s\n" % stdout_line)
                sys.stdout.flush()

        moves = [int(x) for x in lines[-5].split(" ")[:-1]]
        soln = moves[::-1]
        num_nodes_gen_idx = int(lines[-3])
        solve_time = float(lines[-1])

        # record solution information
        path: List[State] = [state]
        next_state: State = state
        transition_costs: List[float] = []

        for move in soln:
            next_states, tcs = env.next_state([next_state], move)

            next_state = next_states[0]
            tc = tcs[0]

            path.append(next_state)
            transition_costs.append(tc)

        solns.append(soln)
        paths.append(path)
        times.append(solve_time)
        num_nodes_gen.append(num_nodes_gen_idx)

        path_cost: float = sum(transition_costs)

        # check soln
        assert search_utils.is_valid_soln(state, soln, env)

        # print to screen
        print("State: %i, SolnCost: %.2f, # Moves: %i, "
              "# Nodes Gen: %s, Time: %.2f" %
              (state_idx, path_cost, len(soln), format(num_nodes_gen_idx,
                                                       ","), solve_time))

    os.unlink(socket_name)

    nnet_utils.stop_heuristic_fn_runners(heur_procs, heur_fn_i_q)

    return solns, paths, times, num_nodes_gen
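
Both A* examples end by asserting search_utils.is_valid_soln(state, soln, env). A minimal sketch of such a check follows, replaying the moves with env.next_state exactly as in the path reconstruction above; the env.is_solved call is an assumption about the environment API.

# Hypothetical sketch of an is_valid_soln-style check; env.is_solved is an assumed method
from typing import List


def is_valid_soln_sketch(state, soln: List[int], env) -> bool:
    curr = state
    for move in soln:
        next_states, _ = env.next_state([curr], move)  # same call pattern as in bwas_cpp
        curr = next_states[0]

    # the final state reached by replaying the solution must be a goal state
    return bool(env.is_solved([curr])[0])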
Example #5
def main():
    # parse arguments
    parser: ArgumentParser = ArgumentParser()
    parser.add_argument('--env', type=str, required=True, help="Environment: cube3, 15-puzzle, 24-puzzle")
    parser.add_argument('--num_states', type=int, default=100, help="Number of states to generate")
    parser.add_argument('--back_max', type=int, default=30, help="Maximum number of backwards steps from the goal "
                                                                  "to take when generating states")

    args = parser.parse_args()

    # get environment
    env: Environment = env_utils.get_environment(args.env)

    # generate goal states
    start_time = time.time()
    states: List[State] = env.generate_goal_states(args.num_states)

    elapsed_time = time.time() - start_time
    states_per_sec = len(states) / elapsed_time
    print("Generated %i goal states in %s seconds (%.2f/second)" %
          (len(states), elapsed_time, states_per_sec))

    # get data
    start_time = time.time()
    states: List[State]
    states, _ = env.generate_states(args.num_states, (0, args.back_max))

    elapsed_time = time.time() - start_time
    states_per_sec = len(states) / elapsed_time
    print("Generated %i states in %s seconds (%.2f/second)" %
          (len(states), elapsed_time, states_per_sec))

    # expand
    start_time = time.time()
    env.expand(states)
    elapsed_time = time.time() - start_time
    states_per_sec = len(states) / elapsed_time
    print("Expanded %i states in %s seconds (%.2f/second)" %
          (len(states), elapsed_time, states_per_sec))

    # nnet format
    start_time = time.time()

    states_nnet = env.state_to_nnet_input(states)

    elapsed_time = time.time() - start_time
    states_per_sec = len(states) / elapsed_time
    print("Converted %i states to nnet format in "
          "%s seconds (%.2f/second)" %
          (len(states), elapsed_time, states_per_sec))

    # get heuristic fn
    on_gpu: bool
    device: torch.device
    device, devices, on_gpu = nnet_utils.get_device()
    print("device: %s, devices: %s, on_gpu: %s" % (device, devices, on_gpu))

    nnet: nn.Module = env.get_nnet_model()
    nnet.to(device)
    if on_gpu:
        nnet = nn.DataParallel(nnet)

    # nnet initialize
    print("")
    heuristic_fn = nnet_utils.get_heuristic_fn(nnet, device, env)
    heuristic_fn(states)

    # compute
    start_time = time.time()
    heuristic_fn(states)

    nnet_time = time.time() - start_time
    states_per_sec = len(states) / nnet_time
    print("Computed heuristic for %i states in %s seconds (%.2f/second)" %
          (len(states), nnet_time, states_per_sec))

    # multiprocessing
    print("")
    start_time = time.time()
    ctx = get_context("spawn")
    queue1: ctx.Queue = ctx.Queue()
    queue2: ctx.Queue = ctx.Queue()
    proc = ctx.Process(target=data_runner, args=(queue1, queue2))
    proc.daemon = True
    proc.start()
    print("Process start time: %.2f" % (time.time() - start_time))

    queue1.put(states_nnet)
    queue2.get()

    start_time = time.time()
    queue1.put(states_nnet)
    print("State nnet send time: %s" % (time.time() - start_time))

    start_time = time.time()
    queue2.get()
    print("States nnet receive time: %.2f" % (time.time() - start_time))

    start_time = time.time()
    proc.join()
    print("Process join time: %.2f" % (time.time() - start_time))