Code example #1
File: untitled3.py  Project: ai4ce/SNAC
import gym
from DMP_Env_1D_static_MCTS import deep_mobile_printing_1d1r_MCTS
import matplotlib.pyplot as plt
import os
import uct
import numpy as np
import time

save_path = "./log/"
if not os.path.exists(save_path):
    os.makedirs(save_path)
### Parameters
env = deep_mobile_printing_1d1r_MCTS(plan_choose=0)
agent = uct.UCT(action_space=env.action_space,
                rollouts=2,
                horizon=0,
                ucb_constant=0.5,
                is_model_dynamic=True)

### Run
# =============================================================================
# env.reset()
# done = False
#
# fig = plt.figure(figsize=(5, 5))
# ax = fig.add_subplot(1, 1, 1)
# print(env.total_brick)
# print(env.one_hot)
# ax.clear()
# Total_reward=0
# while True:
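
The commented-out run loop above is truncated in the original snippet. As a rough, hedged sketch of how such a loop might continue, the following assumes the UCT agent exposes an act(env, done) method (a dyna-gym-style planner API) and that env.step returns the usual (observation, reward, done, info) tuple; none of this is taken verbatim from the SNAC project.

# Hedged sketch only: a possible evaluation loop for the agent built above.
# Assumption: agent.act(env, done) plans with MCTS rollouts and returns an action.
env.reset()
done = False
total_reward = 0
while not done:
    action = agent.act(env, done)
    obs, reward, done, info = env.step(action)
    total_reward += reward
print("total reward:", total_reward)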
Code example #2
    def next_move(self):
        return uct.UCT(self.state, self.n_iter, self.verbose)
Code example #3
File: MCTS_DQN_static.py  Project: ai4ce/SNAC
    def forward(self, s, a):
        x = torch.cat((s, a), dim=1)
        x = self.fc_1(x)
        x = self.relu(x)
        x = self.fc_2(x)
        x = self.relu(x)
        x = self.fc_3(x)
        x = self.relu(x)
        Q = self.out(x)
        return Q


UCT_mcts = uct.UCT(action_space=env.action_space,
                   rollouts=ROLLOUT,
                   horizon=100,
                   ucb_constant=UCB_CONSTANT,
                   is_model_dynamic=True)


class DQN_AGNET():
    def __init__(self, device):
        self.Eval_net = Q_NET().to(device)
        self.Target_net = Q_NET().to(device)
        self.device = device
        self.learn_step = 0  # count learning steps so the target network can be updated periodically
        self.count_memory = 0  # count stored transitions
        # self.replay_memory = np.zeros((Replay_memory_size, State_dim * 2 + 2))
        self.replay_memory = deque(maxlen=Replay_memory_size)
        self.optimizer = torch.optim.Adam(self.Eval_net.parameters(), lr=Lr)
        self.greedy_epsilon = 0.2
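
The Q_NET constructor is not shown in these snippets, so the actual layer sizes are unknown. The minimal stand-in below (TinyQNet, with assumed STATE_DIM, ACTION_DIM, and hidden width) only illustrates how a state-action value network with the same forward(s, a) interface is queried; it is a sketch, not the project's Q_NET.

# Hedged sketch: a hypothetical stand-in with the same (s, a) -> Q interface as Q_NET above.
# STATE_DIM, ACTION_DIM, and HIDDEN are assumptions, not values from the project.
import torch
import torch.nn as nn

STATE_DIM, ACTION_DIM, HIDDEN = 20, 1, 64

class TinyQNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(STATE_DIM + ACTION_DIM, HIDDEN)
        self.fc_2 = nn.Linear(HIDDEN, HIDDEN)
        self.out = nn.Linear(HIDDEN, 1)
        self.relu = nn.ReLU()

    def forward(self, s, a):
        x = torch.cat((s, a), dim=1)   # concatenate state and action along the feature axis
        x = self.relu(self.fc_1(x))
        x = self.relu(self.fc_2(x))
        return self.out(x)             # one Q(s, a) value per sample

q_net = TinyQNet()
s = torch.randn(32, STATE_DIM)         # a batch of 32 states
a = torch.randn(32, ACTION_DIM)        # a batch of 32 action encodings
print(q_net(s, a).shape)               # torch.Size([32, 1])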
Code example #4
File: test_MCTS_DQN.py  Project: ai4ce/SNAC
        self.relu = nn.ReLU()
    def forward(self, s, a):
        x = torch.cat((s, a), dim=1)
        x = self.fc_1(x)
        x = self.relu(x)
        x = self.fc_2(x)
        x = self.relu(x)
        x = self.fc_3(x)
        x = self.relu(x)
        Q = self.out(x)
        return Q

UCT_mcts = uct.UCT(
    action_space=env.action_space,
    rollouts=20,
    horizon=100,
    ucb_constant=0.5,
    is_model_dynamic=True
)

class DQN_AGNET():
    def __init__(self, device):
        self.Eval_net = Q_NET().to(device)
        self.Target_net = Q_NET().to(device)
        self.device = device
        self.learn_step = 0  # count learning steps so the target network can be updated periodically
        self.count_memory = 0  # count stored transitions
        # self.replay_memory = np.zeros((Replay_memory_size, State_dim * 2 + 2))
        self.replay_memory = deque(maxlen=Replay_memory_size)
        self.optimizer = torch.optim.Adam(self.Eval_net.parameters(), lr=Lr)
        self.greedy_epsilon = 0.2