コード例 #1
0
ファイル: nnetSolve.py プロジェクト: wuchiuwong/MoFang
def dataListener(dataQueue, resQueue, gpuNum=None):
    """Load the saved cube3 network once, then evaluate queued inputs forever.

    Each item taken from ``dataQueue`` is passed through the network and the
    result is put on ``resQueue``. The loop never returns.

    Args:
        dataQueue: queue supplying network inputs (``get()`` blocks).
        resQueue: queue receiving the network output for each input.
        gpuNum: optional GPU index forwarded to ``nnet_utils.loadNnet``.
    """
    # NOTE(review): ``Environment`` is not assigned in this function, so it
    # must be defined at module scope (the original local lookup via
    # ``env_utils.getEnvironment('cube3')`` was commented out) — confirm.
    nnet = nnet_utils.loadNnet('savedModels/cube3/1/', 'model.meta', True,
                               Environment, gpuNum=gpuNum)
    while True:
        data = dataQueue.get()
        nnetResult = nnet(data)
        resQueue.put(nnetResult)
コード例 #2
0
def dataListener(dataQueue, resQueue, args, useGPU, Environment, gpuNum=None):
    """Load the model named by ``args`` and evaluate queued inputs forever.

    Args:
        dataQueue: source of network inputs; ``get()`` blocks until one arrives.
        resQueue: sink for the network output computed for each input.
        args: object exposing ``model_loc`` and ``model_name`` attributes.
        useGPU: forwarded to ``nnet_utils.loadNnet``.
        Environment: environment object forwarded to ``nnet_utils.loadNnet``.
        gpuNum: optional GPU index forwarded to ``nnet_utils.loadNnet``.
    """
    location = args.model_loc
    checkpoint = args.model_name
    evaluate = nnet_utils.loadNnet(location, checkpoint, useGPU, Environment,
                                   gpuNum=gpuNum)
    # Serve requests one at a time, never returning.
    while True:
        batch = dataQueue.get()
        answer = evaluate(batch)
        resQueue.put(answer)
コード例 #3
0
ファイル: tools.py プロジェクト: PanQL/DeepCube
def dataListener(dataQueue, resQueue, gpuNum=None):
    """Load the network from module-level settings, then evaluate inputs forever.

    Reads the module-level names ``model_loc``, ``model_name`` and
    ``Environment``. The third ``loadNnet`` argument (presumably the GPU flag)
    is fixed to False.

    Args:
        dataQueue: queue supplying network inputs (``get()`` blocks).
        resQueue: queue receiving the network output for each input.
        gpuNum: optional GPU index forwarded to ``nnet_utils.loadNnet``.
    """
    # Load the neural-network model.
    model = nnet_utils.loadNnet(model_loc, model_name, False, Environment,
                                gpuNum=gpuNum)
    while True:
        inputs = dataQueue.get()
        outputs = model(inputs)
        resQueue.put(outputs)
コード例 #4
0
def heurProc(dataQueue, resQueue, modelLoc, useGPU, Environment, gpuNum=None):
    """Worker loop that evaluates a heuristic on queued batches.

    Each item read from ``dataQueue`` is a ``(data, realWorld)`` pair. A
    ``data`` of ``None`` is a shutdown sentinel: ``None`` is put on
    ``resQueue`` and the loop exits. Otherwise the heuristic's result for
    ``data`` (called with ``realWorld=realWorld``) is put on ``resQueue``.

    Args:
        dataQueue: queue of ``(data, realWorld)`` pairs.
        resQueue: queue receiving one result per input, then ``None`` on exit.
        modelLoc: location of the saved network; ``""`` selects a trivial
            heuristic that returns an integer zero for every row of ``data``.
        useGPU: forwarded to ``nnet_utils.loadNnet``.
        Environment: forwarded to ``nnet_utils.loadNnet``.
        gpuNum: optional GPU index forwarded to ``nnet_utils.loadNnet``.
    """
    if modelLoc == "":
        def nnet(x, realWorld):
            # Zero heuristic: one 0 per input row. ``np.int`` (the original
            # dtype) was an alias of builtin ``int`` and was removed in
            # NumPy 1.24, so ``int`` is the exact, supported replacement.
            return np.zeros([x.shape[0]], dtype=int)
    else:
        nnet = nnet_utils.loadNnet(modelLoc, "", useGPU, Environment,
                                   gpuNum=gpuNum)
    while True:
        data, realWorld = dataQueue.get()
        if data is None:
            # Echo the sentinel so the consumer knows this worker stopped.
            resQueue.put(None)
            break
        nnetResult = nnet(data, realWorld=realWorld)
        resQueue.put(nnetResult)
コード例 #5
0
def dataListener(dataQueue, resQueue, gpuNum=None):
    """Load the bundled cube3 network and evaluate queued inputs forever.

    The checkpoint is ``model.meta`` inside ``DeepCube/savedModels/cube3/1/``
    under the module-level ``BASE_DIR``; the environment comes from
    ``env_utils.getEnvironment('CUBE3')``, and the third ``loadNnet``
    argument (presumably the GPU flag) is fixed to True.

    Args:
        dataQueue: queue supplying network inputs (``get()`` blocks).
        resQueue: queue receiving the network output for each input.
        gpuNum: optional GPU index forwarded to ``nnet_utils.loadNnet``.
    """
    checkpoint_dir = os.path.join(BASE_DIR, "DeepCube/savedModels/cube3/1/")
    checkpoint_file = 'model.meta'
    cube_env = env_utils.getEnvironment('CUBE3')
    network = nnet_utils.loadNnet(checkpoint_dir, checkpoint_file, True,
                                  cube_env, gpuNum=gpuNum)
    while True:
        inputs = dataQueue.get()
        resQueue.put(network(inputs))
コード例 #6
0
ファイル: testNnet.py プロジェクト: HITXZA/cubebot
# Unpack the parsed command-line options used by the test run.
modelLoc = args.model_loc
numStates = args.num_states
maxTurns = args.max_turns
searchDepth = args.search_depth
numRollouts = args.num_rollouts
solveMethod = args.method.upper()
verbose = args.verbose

# Without an explicit turn limit, fall back to the largest scramble depth.
if maxTurns is None:
    maxTurns = args.max_s

load_start_time = time.time()
# These solve methods need a learned heuristic, so a model location is
# mandatory. An explicit raise (rather than ``assert``) still fires under -O.
if solveMethod in ("BFS", "MCTS", "MCTS_SOLVE", "BESTFS"):
    if modelLoc == "":
        raise ValueError("a model location (args.model_loc) is required "
                         "for solve method %s" % solveMethod)
    ### Restore session
    heuristicFn = nnet_utils.loadNnet(modelLoc, args.model_name,
                                      not args.noGPU, Environment)

print("Loaded: %s" % (time.time() - load_start_time))

### Run network on different scrambles
# NOTE(review): ``range`` excludes ``max_s`` while ``np.linspace`` includes
# it, so the tested depths differ in their upper bound depending on which
# branch runs — confirm which behavior is intended.
scrambleTests = range(args.min_s, args.max_s, 1)
if args.max_s - args.min_s > 30:
    # ``np.int`` was removed in NumPy 1.24; it was an alias of ``int``.
    scrambleTests = np.linspace(args.min_s, args.max_s, 30, dtype=int)

for scrambleNum in scrambleTests:
    solve_start_time = time.time()
    # Solve cubes
    testStates_cube, _ = Environment.generate_envs(numStates,
                                                   [scrambleNum, scrambleNum])
    testStates = Environment.state_to_nnet_input(np.stack(testStates_cube))
    if solveMethod == "BFS" or solveMethod == "MCTS":
コード例 #7
0
ファイル: LightsOut.py プロジェクト: HITXZA/cubebot
        for move in moves:
            self.state = self.env.next_state(self.state, move)
            self._updatePlot()
            time.sleep(0.5)


if __name__ == '__main__':
    # Command-line options: board size, heuristic model location, initial board.
    parser = argparse.ArgumentParser()
    parser.add_argument('--N', type=int, default=5, help="")
    parser.add_argument('--heur', type=str, default=None, help="")
    parser.add_argument('--init', type=str, default=None, help="")
    args = parser.parse_args()

    env = LightsOut(args.N)

    # Start from env.solvedState unless --init is given; in that case every
    # non-digit character is stripped and each remaining digit is one value.
    if args.init is not None:
        digits = re.sub("[^0-9]", "", args.init)
        state = np.array([[int(ch) for ch in digits]])
    else:
        state = np.array([env.solvedState])

    # Optional heuristic network; the third loadNnet argument is False
    # (presumably the GPU flag).
    heuristicFn = (nnet_utils.loadNnet(args.heur, "", False, env)
                   if args.heur is not None else None)

    fig = plt.figure(figsize=(5, 5))
    interactiveEnv = InteractiveEnv(state, env, heuristicFn)
    fig.add_axes(interactiveEnv)

    plt.show()