コード例 #1
0
ファイル: DQN_MCTS_TRAIN.py プロジェクト: syd951186545/BiYe
                    loss = self.loss_func(target_v, Qsa)
                    losses += loss
                self.optimizer.zero_grad()
                losses.backward()
                self.optimizer.step()
                self.writer.add_scalar('loss/value_loss', losses / self.batch_size, self.update_count)
                self.update_count += 1
                if self.update_count % 500 == 0:
                    self.target_net.load_state_dict(self.act_net.state_dict())
                    torch.save(self.act_net.state_dict(), config.act_net_model_dir + str(self.update_count) + ".model")


from BiYeSheJi.Module.environment import Environment

# Module-level environment built over Graph; init_TreeList is handed the path
# to train.txt under config.dataSet.
env = Environment(Graph)
env.init_TreeList(config.dataSet + "/train.txt")


def main():
    """Run ``num_episodes`` episodes, stepping from the reset root via Policy_MCTS.

    Each episode resets the module-level ``env`` and repeatedly replaces the
    current node with the one returned by ``Policy_MCTS`` until the current
    node equals the root's target node.
    """
    agentP = DQN()
    find_target = 0
    for i_ep in range(num_episodes):
        root = env.reset()
        if render: env.render()
        # node_t is the best next node chosen by UCT after one pass of
        # Monte Carlo tree search.
        path = [root.state.state_tup, ]
        node_t = root
        reward = 0
        # Keep searching until the current node is the episode's target node.
        while node_t.state.current_node != root.state.target_node:
            node_t = Policy_MCTS(node_t, agentP.act_net)
コード例 #2
0
ファイル: DQN.py プロジェクト: syd951186545/BiYe
from torch.distributions import Normal, Categorical
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from tensorboardX import SummaryWriter

from BiYeSheJi.Module.environment import Environment
from BiYeSheJi.Script.buildNetworkGraph import get_relation_num, get_graph
from BiYeSheJi.Module.state_encoder import StateEncoder
from BiYeSheJi.Script.get_embedding import get_node_emb_matrix, get_relation_emb_dic
from BiYeSheJi.Configuration import config

# Hyper-parameters
seed = 1  # single seed applied to both torch and Python's random below
render = False  # NOTE(review): presumably toggles env.render() in the episode loop — confirm
Graph = get_graph()  # graph object handed to Environment below
num_episodes = 400000
env = Environment(Graph)
num_state = env.state_dim  # state dimensionality as reported by the environment
num_action = 128
# NOTE(review): seeding happens after get_graph() and Environment(Graph) have
# run, so any randomness inside those calls is not covered by this seed.
torch.manual_seed(seed)
random.seed(seed)

# Record type for one transition: (state, action, reward, next_state).
Transition = namedtuple('Transition',
                        ['state', 'action', 'reward', 'next_state'])

relation_dim = get_relation_num()
# "liner" in the names below presumably means "linear"; names are left as-is
# because other code may reference them.
INPUT_SIZE_liner = relation_dim + 128
INPUT_SIZE_GRU = 3 * 128

OUT_SIZE_liner = 128
HIDDEN_SIZE_GRU = relation_dim
コード例 #3
0
import random

from BiYeSheJi.Module.environment import Environment
from BiYeSheJi.Script.buildNetworkGraph import get_graph
import torch
import json

# Smoke script: reset the environment for a fixed (start, target, query)
# triple, then take five randomly chosen steps, printing the environment's
# current node, path and `st` attribute after each step.
random.seed(1)  # fixed seed so the sequence of random choices repeats
start_node_id, target_node_id, query_id = 2, 3, 4
env = Environment(get_graph())
env.reset(start_node_id, target_node_id, query_id)
print(env.st)
for i in range(5):
    # Candidate actions are the keys of the graph entry for the current node.
    candidates = list(env.Graph[env.current_node].keys())
    action_node_id = random.choice(candidates)
    env.step(action_node_id)
    # One value per line, same order as three separate print calls.
    print(env.current_node, env.path, env.st, sep="\n")