def __init__(self, args):
    """Create the prediction network and load the lookup tables it needs.

    Args:
        args: Run configuration. This method reads ``args.cuda`` (when
            truthy the network is moved to the GPU) and ``args.save_dir``
            (created if it does not exist yet).

    Raises:
        OSError: If one of the dictionary files cannot be opened.
    """
    self.num_inputs = 856  # dimension of the vector fed into the prediction network
    self.num_output = 5  # dimension of the vector the network outputs
    self.args = args
    self.right = 0
    self.dataList = []
    self.testList = []

    # build up the network
    self.actor_net = ValueNet(self.num_inputs, self.num_output)
    if self.args.cuda:
        self.actor_net.cuda()

    # make sure the checkpoint directory exists (including parents)
    if not os.path.exists(self.args.save_dir):
        os.makedirs(self.args.save_dir)

    # Read the predicate-encoding dictionary.
    # NOTE(review): eval() executes whatever the file contains; only point
    # these paths at trusted files (ast.literal_eval would be safer if the
    # files hold plain literals -- confirm before switching).
    with open(predicatesEncodeDictPath, 'r') as f:
        self.predicatesEncodeDict = eval(f.read())

    # Read all table names and map each one to its index in sorted order.
    with open(shortToLongPath, 'r') as f:
        short_to_long = eval(f.read())
    tables = sorted(short_to_long.keys())
    self.table_to_int = {name: idx for idx, name in enumerate(tables)}
def __init__(self, args):
    """Load the lookup tables, size the network from them and build it.

    Args:
        args: Run configuration. This method reads ``args.save_dir``
            (created if it does not exist yet).

    Raises:
        OSError: If one of the dictionary files cannot be opened.
        KeyError: If the predicate dictionary has no ``"1a"`` entry, which
            is used to measure the predicate-encoding length.
    """
    # Read dict predicatesEncoded.
    # NOTE(review): eval() executes whatever the file contains; only point
    # these paths at trusted files.
    with open(predicatesEncodeDictPath, 'r') as f:
        self.predicatesEncodeDict = eval(f.read())

    # Read all table names and map each one to its index in sorted order.
    with open(shortToLongPath, 'r') as f:
        short_to_long = eval(f.read())
    tables = sorted(short_to_long.keys())
    self.table_to_int = {name: idx for idx, name in enumerate(tables)}

    # Input dimension: one cell per ordered table pair plus the length of
    # one query's predicate encoding (query "1a" is used as the sample).
    self.num_inputs = len(tables) * len(tables) + len(self.predicatesEncodeDict["1a"])
    # The dimension of the vector output by the network
    self.num_output = 5
    self.args = args
    self.right = 0

    # build up the network
    self.value_net = ValueNet(self.num_inputs, self.num_output)

    # make sure the checkpoint directory exists (including parents)
    if not os.path.exists(self.args.save_dir):
        os.makedirs(self.args.save_dir)

    self.dataList = []
    self.testList = []
class Selector:
    """Choose moves with a learned policy, switching to a search policy late in the game.

    ``ai_type`` has the form ``"<policy>-<strategy>"``. The policy part picks
    the network used while more than 12 squares are empty; the strategy part
    picks the ``MontecarloPolicy`` variant used for the remaining moves.
    """

    # TODO: what this should hold is the history of past moves and the
    # current situation.

    _KNOWN_POLICIES = ('slpolicy', 'slpolicy2', 'valuepolicy')

    def __init__(self, ai_type):
        """Load the value network and the policy named by ``ai_type``.

        Args:
            ai_type: ``"<policy>-<strategy>"`` where policy is one of
                ``slpolicy``, ``slpolicy2``, ``valuepolicy`` and strategy is
                ``win`` or ``draw``.

        Raises:
            ValueError: If the policy part is unknown, or if ``ai_type``
                does not contain exactly one ``-``.
        """
        self.basic_policy, self.strategy = ai_type.split('-')
        self.value_net = ValueNet()
        serializers.load_hdf5(value_net_path, self.value_net)
        if self.basic_policy == 'slpolicy':
            self.policy = SLPolicy()
            serializers.load_hdf5(sl_policy_path1, self.policy)
        elif self.basic_policy == 'slpolicy2':
            self.policy = SLPolicy()
            serializers.load_hdf5(sl_policy_path2, self.policy)
        elif self.basic_policy == 'valuepolicy':
            self.policy = self.value_net
        else:
            raise ValueError('invalid policy')
        # NOTE(review): any other strategy leaves ``traverse_policy`` unset,
        # so ``act1`` fails with AttributeError once the endgame is reached.
        if self.strategy == 'win':
            self.traverse_policy = MontecarloPolicy(strategy=1)
        elif self.strategy == 'draw':
            self.traverse_policy = MontecarloPolicy(strategy=3)
        self.count = 0
        self.moves_history = []

    def act(self, b, color, turn):
        """Return the chosen action for the position, or -1 to pass.

        Also prints the value network's arg-max class shifted by -20
        (presumably the predicted score margin -- confirm against the
        network's output layout).
        """
        state = board.to_state(b, color, turn)
        # An all-zero plane 2 means there is nothing to play
        # (presumably the legal-move plane -- TODO confirm).
        if state[2].sum() == 0:
            print('pass')
            return -1
        current_score = self.value_net.predict(
            self.value_net.xp.array([state])).data[0]
        current_score = softmax(current_score)
        print(np.argmax(current_score) - 20)
        # The original test was ``x == 'slpolicy' or 'slpolicy2' or ...``,
        # which is always truthy; compare against every known name instead.
        if self.basic_policy in self._KNOWN_POLICIES:
            action = self.act1(b, color, turn)
        else:
            action = -1
        return action

    def act1(self, b, color, turn):
        """Delegate to the learned policy, or to the search policy in the endgame."""
        stone_cnt = b[0].sum() + b[1].sum()
        # More than 12 empty squares: trust the network; otherwise search.
        if 64 - stone_cnt > 12:
            action = self.policy.act(b, color, turn, temperature=1)
        else:
            print("TRAVERSE", stone_cnt)
            action = self.traverse_policy.act(b, color, turn, temperature=1)
        return action
def __init__(self, ai_type):
    """Parse ``"<policy>-<strategy>"`` and load the matching networks.

    The value network is always loaded. The move policy is an ``SLPolicy``
    restored from one of two weight files, or the value network itself.
    An unknown policy name raises ``ValueError``; an unknown strategy
    leaves ``traverse_policy`` unset.
    """
    policy_name, strategy_name = ai_type.split('-')
    self.basic_policy = policy_name
    self.strategy = strategy_name

    value_net = ValueNet()
    self.value_net = value_net
    serializers.load_hdf5(value_net_path, value_net)

    if policy_name in ('slpolicy', 'slpolicy2'):
        # Both supervised variants share the architecture and differ only
        # in which weight file is restored.
        self.policy = SLPolicy()
        weights = sl_policy_path1 if policy_name == 'slpolicy' else sl_policy_path2
        serializers.load_hdf5(weights, self.policy)
    elif policy_name == 'valuepolicy':
        self.policy = value_net
    else:
        raise ValueError('invalid policy')

    if strategy_name == 'win':
        self.traverse_policy = MontecarloPolicy(strategy=1)
    elif strategy_name == 'draw':
        self.traverse_policy = MontecarloPolicy(strategy=3)

    self.count = 0
    self.moves_history = []
def __init__(self, params: Parameters):
    """Build the environment, seed all RNGs, and create the critic and actor.

    Reads ``game``, ``gamma``, ``seed``, ``lr_c`` and ``lr_a`` from
    ``params`` and turns on TensorFlow graph tracing at the end.
    """
    self.parms = params
    environment = Env(params.game, params.gamma)
    self.env = environment

    # Seed the environment, NumPy and TensorFlow from the same value.
    seed = params.seed
    environment.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    self.critic = ValueNet(lr=params.lr_c)
    self.actor = CtsPolicy(
        action_bound=environment.action_bound,
        action_dim=environment.num_actions,
        lr=params.lr_a,
    )

    tf.summary.trace_on(graph=True)
from __future__ import division
import time
import math
import random
from copy import deepcopy
import numpy as np
from models import ValueNet
import torch

# Checkpoint of the supervised prediction network; it is loaded once, when
# this module is imported.
model_path = './saved_models/supervised.pt'
# 856 inputs and 5 outputs; the checkpoint must have been trained with the
# same dimensions.
predictionNet = ValueNet(856, 5)
# The map_location callable returns the storage unchanged, i.e. the tensors
# stay on the CPU.
predictionNet.load_state_dict(
    torch.load(model_path, map_location=lambda storage, loc: storage))
predictionNet.eval()


def getReward(state):
    """Return a reward in [0, 1] derived from the network's predicted class.

    The network input is ``state.board + state.predicatesEncode``
    (presumably two flat lists that are concatenated -- TODO confirm).
    The arg-max output index 0 maps to reward 1.0 and index 4 to 0.0.
    """
    inputState = torch.tensor(state.board + state.predicatesEncode, dtype=torch.float32)
    with torch.no_grad():
        predictionRuntime = predictionNet(inputState)
    prediction = predictionRuntime.detach().cpu().numpy()
    maxindex = np.argmax(prediction)
    reward = (4 - maxindex) / 4.0
    return reward


def randomPolicy(state):
    # NOTE(review): the remainder of this function is not visible in this
    # excerpt; the lines below are reproduced exactly as far as they go.
    while not state.isTerminal():
        try:
            temp = state.getPossibleActions()
def main():
    """Train the value network from the command line.

    Parses the CLI options, creates a dated run directory under
    ``runs_value``, records the settings, then trains for ``--epoch``
    epochs. After every epoch the validation set is evaluated and the
    metrics are printed and appended to ``log``; the model is saved every
    fifth epoch.
    """
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('--gpu', '-g', type=int, default=-1,
                     help='GPU device ID')
    cli.add_argument('--epoch', '-e', type=int, default=50,
                     help='# of epoch')
    cli.add_argument('--batch_size', type=int, default=128,
                     help='size of mini-batch')
    cli.add_argument('--density', type=int, default=1,
                     help='density of cnn kernel')
    cli.add_argument('--small', dest='small', action='store_true',
                     default=False)
    cli.add_argument('--no_bn', dest='use_bn', action='store_false',
                     default=True)
    cli.add_argument('--out', default='')
    cli.set_defaults(test=False)
    args = cli.parse_args()

    # Alternative model: RolloutValueNet(use_bn=args.use_bn, output=41)
    model = ValueNet(use_bn=args.use_bn)

    # Run directory: runs_value/<MMDD>[_<out>]/models
    run_name = datetime.datetime.now().strftime('%m%d')
    if args.out:
        run_name = run_name + '_' + args.out
    out_dir = os.path.abspath(
        os.path.join(os.path.curdir, "runs_value", run_name))
    model_dir = os.path.join(out_dir, 'models')
    os.makedirs(model_dir, exist_ok=True)

    # Move the model to the requested GPU, if any.
    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        model.to_gpu()

    # Echo the effective settings and keep a copy next to the results.
    with open(os.path.join(out_dir, 'setting.txt'), 'w') as settings_file:
        for key, value in args._get_kwargs():
            entry = '{} = {}'.format(key, value)
            print(entry)
            settings_file.write(entry + '\n')

    # Datasets and iterators.
    train = PreprocessedDataset(train_small_path if args.small else train_path)
    test = PreprocessedDataset(test_path)
    train_iter = iterators.SerialIterator(train, args.batch_size)
    val_iter = iterators.SerialIterator(test, args.batch_size, repeat=False)

    optimizer = chainer.optimizers.Adam(eps=1e-2)
    optimizer.setup(model)

    def to_variables(batch, **variable_kwargs):
        # Turn a list of (input, label) samples into two chainer Variables.
        xs = model.xp.array([sample[0] for sample in batch], 'float32')
        ys = model.xp.array([sample[1] for sample in batch], 'int32')
        return (chainer.Variable(xs, **variable_kwargs),
                chainer.Variable(ys, **variable_kwargs))

    start = time.time()
    train_count = 0
    batches_per_epoch = len(train) // args.batch_size
    for epoch in range(args.epoch):
        # --- training pass ---
        train_loss, train_accuracy = [], []
        for _ in range(batches_per_epoch):
            x, y = to_variables(train_iter.next())
            optimizer.update(model, x, y)
            train_count += 1
            progress_report(train_count, start, args.batch_size)
            train_loss.append(cuda.to_cpu(model.loss.data))
            train_accuracy.append(cuda.to_cpu(model.accuracy.data))

        # --- validation pass on a fresh copy of the iterator ---
        test_loss, test_accuracy = [], []
        for batch in copy.copy(val_iter):
            x, y = to_variables(batch, volatile=True)
            model(x, y, train=False)
            test_loss.append(cuda.to_cpu(model.loss.data))
            test_accuracy.append(cuda.to_cpu(model.accuracy.data))

        summary = (epoch, np.mean(train_loss), np.mean(train_accuracy),
                   np.mean(test_loss), np.mean(test_accuracy))
        print('\nepoch {} train_loss {:.5f} train_accuracy {:.3f} \n'
              ' test_loss {:.5f} test_accuracy {:.3f}'.format(*summary))
        with open(os.path.join(out_dir, "log"), 'a+') as log_file:
            log_file.write(
                'epoch {} train_loss {:.5f} train_accuracy {:.3f} \n'
                ' test_loss {:.5f} test_accuracy {:.3f} \n'.format(*summary))

        if epoch % 5 == 0:
            serializers.save_hdf5(
                os.path.join(model_dir, "value_net_{}.model".format(epoch)),
                model)
class supervised:
    """Supervised trainer for a network that sorts query plans into runtime buckets.

    Samples come from ``queryName,hint,runtime`` lines. The join order in the
    hint is encoded as a 28 x 28 matrix, concatenated with the query's
    predicate encoding, and labelled with one of ``num_output`` buckets after
    all samples have been ranked by runtime.
    """

    def __init__(self, args):
        """Create the network and load the lookup tables.

        Args:
            args: Run configuration. Reads ``args.cuda`` (move the network
                to the GPU when truthy) and ``args.save_dir`` here, and
                ``args.critic_lr`` later in :meth:`supervised`.
        """
        self.num_inputs = 856  # dimension of the vector fed into the prediction network
        self.num_output = 5  # dimension of the vector the network outputs
        self.args = args
        self.right = 0
        self.dataList = []
        self.testList = []

        # build up the network
        self.actor_net = ValueNet(self.num_inputs, self.num_output)
        if self.args.cuda:
            self.actor_net.cuda()

        # make sure the checkpoint directory exists (including parents)
        if not os.path.exists(self.args.save_dir):
            os.makedirs(self.args.save_dir)

        # Read the predicate-encoding dictionary.
        # NOTE(review): eval() executes whatever the file contains; only
        # point these paths at trusted files.
        with open(predicatesEncodeDictPath, 'r') as f:
            self.predicatesEncodeDict = eval(f.read())

        # Read all table names and map each one to its index in sorted order.
        with open(shortToLongPath, 'r') as f:
            short_to_long = eval(f.read())
        tables = sorted(short_to_long.keys())
        self.table_to_int = {name: idx for idx, name in enumerate(tables)}

    def _device(self):
        """Return the device the network's parameters currently live on."""
        return next(self.actor_net.parameters()).device

    def _state_tensor(self, state):
        """Build a float32 input tensor on the same device as the network.

        The network may have been moved to CUDA in ``__init__``; a tensor
        left on the CPU would then fail in the forward pass.
        """
        return torch.tensor(state, dtype=torch.float32, device=self._device())

    def _count_correct(self, samples):
        """Return how many ``samples`` the network assigns to their labelled bucket."""
        correct = 0
        with torch.no_grad():
            for sample in samples:
                output = self.actor_net(self._state_tensor(sample.state))
                prediction = output.detach().cpu().numpy()
                if np.argmax(prediction) == sample.label:
                    correct += 1
        return correct

    @staticmethod
    def _ratio(correct, total):
        """Return ``correct / total``, or 0.0 for an empty set."""
        return correct / total if total else 0.0

    def load_data(self):
        """Load the pickled test and train sets once; later calls are no-ops.

        Each file starts with the number of samples, followed by that many
        pickled sample objects (the layout written by :meth:`pretreatment`).
        """
        if self.dataList:
            return
        # NOTE(review): pickle.load runs arbitrary code from the file; these
        # files are expected to be the ones this class wrote itself.
        with open("./data/testdata.sql", 'rb') as file_test:
            for _ in range(pickle.load(file_test)):
                self.testList.append(pickle.load(file_test))
        with open("./data/traindata.sql", 'rb') as file_train:
            for _ in range(pickle.load(file_train)):
                self.dataList.append(pickle.load(file_train))

    def print_data(self):
        """Print the state vector of every training sample."""
        for sample in self.dataList:
            print(sample.state)

    def hint2matrix(self, hint):
        """Encode the join order of a hint as a 28 x 28 matrix.

        Args:
            hint: Space-separated tokens in which every join result is
                wrapped in one pair of parentheses and single tables are
                bare, e.g. ``"( ( a b ) c )"``.

        Returns:
            A ``numpy.matrix`` where cell ``[i, j]`` holds the rank of the
            join between the groups represented by tables ``i`` and ``j``:
            the first join gets the table count, later joins count down.

        Raises:
            KeyError: If a table name is missing from ``table_to_int``.
        """
        tablesInQuery = hint.split(" ")
        # np.asmatrix instead of np.mat: the alias was removed in NumPy 2.0.
        matrix = np.asmatrix(np.zeros((28, 28)))
        stack = []
        step = 0
        # n tables give n bare tokens plus (n - 1) pairs of parentheses,
        # i.e. 3n - 2 tokens, hence the formula below.
        num_table = (len(tablesInQuery) + 2) / 3
        for token in tablesInQuery:
            if token == ')':
                tempb = stack.pop()
                tempa = stack.pop()
                stack.pop()  # discard the matching '('
                # Each operand is a '+'-joined list of already joined
                # tables; the alphabetically smallest one represents it.
                indexb = self.table_to_int[min(tempb.split('+'))]
                indexa = self.table_to_int[min(tempa.split('+'))]
                # record the order of this join
                matrix[indexa, indexb] = num_table - step
                step += 1
                # '+' marks tables that have been joined
                stack.append(tempa + '+' + tempb)
            else:
                stack.append(token)
        return matrix

    def pretreatment(self, path):
        """Build labelled samples from ``path`` and write the train/test split.

        Every line has the form ``queryName,hint,runtime`` where runtime is
        a number or the literal ``timeout``. Samples are ranked by runtime,
        bucketed into ``num_output`` labels, and 30% are moved at random to
        the test set. Both sets are pickled under ``./data``.
        """
        # Read everything first; the random split happens afterwards.
        with open(path) as source:
            for line in source:
                fields = line.split(",")
                # Round-trip through utf-8-sig to drop a leading BOM.
                queryName = fields[0].encode('utf-8').decode(
                    'utf-8-sig').strip()
                matrix = self.hint2matrix(fields[1])
                predicatesEncode = self.predicatesEncodeDict[queryName]
                state = predicatesEncode + matrix.flatten().tolist()[0]
                runtime = fields[2].strip()
                if runtime == 'timeout':
                    # 5 min = 300 s = 300 000 ms
                    runtime = 300000
                else:
                    runtime = int(float(runtime))
                self.dataList.append(data(state, runtime))

        self.dataList.sort(key=lambda x: x.time, reverse=False)
        total = len(self.dataList)
        for i, sample in enumerate(self.dataList):
            sample.label = int(i / (total / self.num_output + 1))
            print(sample.label)

        for _ in range(int(total * 0.3)):
            index = random.randint(0, len(self.dataList) - 1)
            self.testList.append(self.dataList.pop(index))
        print("size of test set:", len(self.testList),
              "\tsize of train set:", len(self.dataList))

        with open("./data/testdata.sql", 'wb') as file_test:
            pickle.dump(len(self.testList), file_test)
            for value in self.testList:
                pickle.dump(value, file_test)
        with open("./data/traindata.sql", 'wb') as file_train:
            pickle.dump(len(self.dataList), file_train)
            for value in self.dataList:
                pickle.dump(value, file_train)

    def supervised(self):
        """Train the network with single-sample SGD and save checkpoints.

        Prints the summed loss every 1000 steps and saves a checkpoint
        every 100000 steps.
        """
        self.load_data()
        optim = torch.optim.SGD(self.actor_net.parameters(),
                                lr=self.args.critic_lr)
        # NOTE(review): NLLLoss expects log-probabilities; presumably the
        # network ends in a log-softmax -- confirm in ValueNet.
        loss_func = torch.nn.NLLLoss()
        device = self._device()
        loss1000 = 0
        count = 0
        for step in range(1, 1600001):
            index = random.randint(0, len(self.dataList) - 1)
            sample = self.dataList[index]
            # network prediction
            predictionRuntime = self.actor_net(self._state_tensor(sample.state))
            # target, on the same device as the prediction
            label_tensor = torch.tensor([sample.label], dtype=torch.long,
                                        device=device)
            loss = loss_func(predictionRuntime.view(1, self.num_output),
                             label_tensor)
            optim.zero_grad()  # clear the gradients
            loss.backward()  # compute the gradients
            optim.step()  # apply the gradients and update the parameters
            loss1000 += loss.item()
            if step % 100000 == 0:
                # save the current model every 100000 training steps
                torch.save(
                    self.actor_net.state_dict(),
                    self.args.save_dir +
                    'supervised{:d}-{:.5f}.pt'.format(count, loss1000))
                count = count + 1
            if step % 1000 == 0:
                # report the total loss of the last 1000 training steps
                print('[{}] Epoch: {}, Loss: {:.5f}'.format(
                    datetime.now(), step, loss1000))
                loss1000 = 0

    # functions to test the network
    def test_network(self):
        """Load a fixed checkpoint and print accuracy on the test and train sets.

        Stores the test-set accuracy in ``self.right``.
        """
        self.load_data()
        model_path = self.args.save_dir + 'supervised15-569.90684.pt'
        self.actor_net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
        self.actor_net.eval()

        correct = self._count_correct(self.testList)
        test_total = len(self.testList)
        test_ratio = self._ratio(correct, test_total)
        print(correct, test_total, test_ratio, end=' ')

        correct1 = self._count_correct(self.dataList)
        train_total = len(self.dataList)
        print(correct1, train_total, self._ratio(correct1, train_total))

        self.right = test_ratio

    def test_hintcost(self, queryName, hint):
        """Load ``supervised.pt`` and print the predicted bucket for one hint."""
        model_path = self.args.save_dir + 'supervised.pt'
        self.actor_net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
        self.actor_net.eval()
        matrix = self.hint2matrix(hint)
        predicatesEncode = self.predicatesEncodeDict[queryName]
        state = predicatesEncode + matrix.flatten().tolist()[0]
        predictionRuntime = self.actor_net(self._state_tensor(state))
        prediction = predictionRuntime.detach().cpu().numpy()
        maxindex = np.argmax(prediction)
        print(maxindex)
class supervised:
    """Supervised trainer for a network that sorts query plans into runtime buckets.

    Samples come from ``queryName,hint,runtime`` lines. The join order in the
    hint is encoded as a square matrix over all known tables, concatenated
    with the query's predicate encoding, and labelled with one of
    ``num_output`` buckets after all samples have been ranked by runtime.
    """

    # Runtime recorded for plans reported as 'timeout'.
    # NOTE(review): the original left this as a placeholder ("Depends on
    # your settings"); 300000 assumes a 5-minute timeout measured in
    # milliseconds -- adjust it to the timeout your benchmark uses.
    TIMEOUT_RUNTIME = 300000

    def __init__(self, args):
        """Load the lookup tables, size the network from them and build it.

        Args:
            args: Run configuration; ``args.save_dir`` is read (and created
                if it does not exist yet).
        """
        # Read dict predicatesEncoded.
        # NOTE(review): eval() executes whatever the file contains; only
        # point these paths at trusted files.
        with open(predicatesEncodeDictPath, 'r') as f:
            self.predicatesEncodeDict = eval(f.read())

        # Read all table names and map each one to its index in sorted order.
        with open(shortToLongPath, 'r') as f:
            short_to_long = eval(f.read())
        tables = sorted(short_to_long.keys())
        self.table_to_int = {name: idx for idx, name in enumerate(tables)}

        # Input dimension: one cell per ordered table pair plus the length
        # of one query's predicate encoding (query "1a" is the sample).
        self.num_inputs = len(tables) * len(tables) + len(self.predicatesEncodeDict["1a"])
        # The dimension of the vector output by the network
        self.num_output = 5
        self.args = args
        self.right = 0

        # build up the network
        self.value_net = ValueNet(self.num_inputs, self.num_output)

        # make sure the checkpoint directory exists (including parents)
        if not os.path.exists(self.args.save_dir):
            os.makedirs(self.args.save_dir)

        self.dataList = []
        self.testList = []

    def _count_correct(self, samples):
        """Return how many ``samples`` the network assigns to their labelled bucket."""
        correct = 0
        with torch.no_grad():
            for sample in samples:
                state_tensor = torch.tensor(sample.state, dtype=torch.float32)
                prediction = self.value_net(state_tensor).detach().cpu().numpy()
                if np.argmax(prediction) == sample.label:
                    correct += 1
        return correct

    @staticmethod
    def _ratio(correct, total):
        """Return ``correct / total``, or 0.0 for an empty set."""
        return correct / total if total else 0.0

    # Parsing query plan
    def hint2matrix(self, hint):
        """Encode the join order of a hint as a table-by-table matrix.

        Args:
            hint: Space-separated tokens in which every join result is
                wrapped in one pair of parentheses and single tables are
                bare, e.g. ``"( ( a b ) c )"``.

        Returns:
            A ``numpy.matrix`` with one row and column per known table.
            Cell ``[i, j]`` holds the rank of the join between the groups
            represented by tables ``i`` and ``j``: the first join gets the
            table count, later joins count down.

        Raises:
            KeyError: If a table name is missing from ``table_to_int``.
        """
        tablesInQuery = hint.split(" ")
        size = len(self.table_to_int)
        # np.asmatrix instead of np.mat: the alias was removed in NumPy 2.0.
        matrix = np.asmatrix(np.zeros((size, size)))
        stack = []
        difference = 0
        for token in tablesInQuery:
            if token == ')':
                tempb = stack.pop()
                tempa = stack.pop()
                stack.pop()  # discard the matching '('
                # Each operand is a '+'-joined list of already joined
                # tables; the alphabetically smallest one represents it.
                indexb = self.table_to_int[min(tempb.split('+'))]
                indexa = self.table_to_int[min(tempa.split('+'))]
                # n tables give 3n - 2 tokens, so (tokens + 2) / 3 == n.
                matrix[indexa, indexb] = (len(tablesInQuery) + 2) / 3 - difference
                difference += 1
                stack.append(tempa + '+' + tempb)
            else:
                stack.append(token)
        return matrix

    # Divide training set and test set
    def pretreatment(self, path):
        """Build labelled samples from ``path`` and write the train/test split.

        Every line has the form ``queryName,hint,runtime`` where runtime is
        a number or the literal ``timeout``. Samples are ranked by runtime,
        bucketed into ``num_output`` labels, and 30% are moved at random to
        the test set. Both sets are pickled under ``./data``.
        """
        # Load data uniformly and randomly select for training
        with open(path) as source:
            for line in source:
                fields = line.split(",")
                # Round-trip through utf-8-sig to drop a leading BOM.
                queryName = fields[0].encode('utf-8').decode('utf-8-sig').strip()
                matrix = self.hint2matrix(fields[1])
                predicatesEncode = self.predicatesEncodeDict[queryName]
                state = matrix.flatten().tolist()[0] + predicatesEncode
                runtime = fields[2].strip()
                if runtime == 'timeout':
                    runtime = self.TIMEOUT_RUNTIME
                else:
                    runtime = int(float(runtime))
                self.dataList.append(data(state, runtime))

        self.dataList.sort(key=lambda x: x.time, reverse=False)
        total = len(self.dataList)
        for i, sample in enumerate(self.dataList):
            sample.label = int(i / (total / self.num_output + 1))

        for _ in range(int(total * 0.3)):
            index = random.randint(0, len(self.dataList) - 1)
            self.testList.append(self.dataList.pop(index))
        print("size of test set:", len(self.testList), "\tsize of train set:", len(self.dataList))

        with open("./data/testdata.sql", 'wb') as file_test:
            pickle.dump(len(self.testList), file_test)
            for value in self.testList:
                pickle.dump(value, file_test)
        with open("./data/traindata.sql", 'wb') as file_train:
            pickle.dump(len(self.dataList), file_train)
            for value in self.dataList:
                pickle.dump(value, file_train)

    # functions to train the network
    def supervised(self):
        """Train the network with single-sample SGD.

        The network output is treated as class probabilities: its log (with
        a small epsilon) is fed to ``NLLLoss``.
        """
        self.load_data()
        optim = torch.optim.SGD(self.value_net.parameters(), lr=0.01)
        loss_func = torch.nn.NLLLoss()
        loss1000 = 0
        for step in range(1, 16000001):
            index = random.randint(0, len(self.dataList) - 1)
            sample = self.dataList[index]
            state_tensor = torch.tensor(sample.state, dtype=torch.float32)
            # 1e-10 keeps log() finite when a probability is exactly 0.
            predictionRuntime = torch.log(self.value_net(state_tensor) + 1e-10)
            predictionRuntime = predictionRuntime.view(1, -1)
            label_tensor = torch.tensor([sample.label])
            loss = loss_func(predictionRuntime, label_tensor)
            optim.zero_grad()
            loss.backward()
            optim.step()
            loss1000 += loss.item()
            # NOTE(review): test_network() reloads 'supervised.pt' from disk
            # and switches the network to eval mode, which replaces the
            # weights being trained. The nesting of the statements below
            # could not be recovered with certainty from the source; the
            # statement order is preserved -- confirm the intended cadence.
            if step % 1000 == 0:
                print('[{}] Epoch: {}, Loss: {:.5f}'.format(datetime.now(), step, loss1000))
                loss1000 = 0
                self.test_network()
                print('[{}] Epoch: {}, Loss: {:.5f}'.format(datetime.now(), step, loss1000))
            if step % 200000 == 0:
                torch.save(self.value_net.state_dict(), self.args.save_dir + 'supervised.pt')
                self.test_network()

    # functions to test the network
    def test_network(self):
        """Load ``supervised.pt`` and print accuracy on the test and train sets.

        Stores the test-set accuracy in ``self.right``.
        """
        self.load_data()
        model_path = self.args.save_dir + 'supervised.pt'
        # The network attribute of this class is ``value_net``; the original
        # referenced a non-existent ``actor_net`` here.
        self.value_net.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
        self.value_net.eval()

        correct = self._count_correct(self.testList)
        test_total = len(self.testList)
        test_ratio = self._ratio(correct, test_total)
        print(correct, test_total, test_ratio, end=' ')

        correct1 = self._count_correct(self.dataList)
        train_total = len(self.dataList)
        print(correct1, train_total, self._ratio(correct1, train_total))

        self.right = test_ratio

    def load_data(self):
        """Load the pickled test and train sets once; later calls are no-ops.

        Each file starts with the number of samples, followed by that many
        pickled sample objects (the layout written by :meth:`pretreatment`).
        """
        if self.dataList:
            return
        # NOTE(review): pickle.load runs arbitrary code from the file; these
        # files are expected to be the ones this class wrote itself.
        with open("./data/testdata.sql", 'rb') as file_test:
            for _ in range(pickle.load(file_test)):
                self.testList.append(pickle.load(file_test))
        with open("./data/traindata.sql", 'rb') as file_train:
            for _ in range(pickle.load(file_train)):
                self.dataList.append(pickle.load(file_train))
class supervised:
    """Supervised trainer for a two-class network over encoded queries.

    Data is split into ``datasetnumber`` pickled folds; one fold is used for
    testing and the others for training.
    """

    def __init__(self, args):
        """Create and initialise the network and load the lookup tables.

        Args:
            args: Run configuration. Reads ``args.cuda`` (move the network
                to the GPU when truthy) and ``args.save_dir``.
        """
        self.num_inputs = 856  # dimension of the vector fed into the prediction network
        self.num_output = 2  # dimension of the vector the network outputs
        self.args = args

        # build up the network
        self.actor_net = ValueNet(self.num_inputs, self.num_output)
        self.actor_net.apply(self.init_weights)
        if self.args.cuda:
            self.actor_net.cuda()

        # make sure the checkpoint directory exists (including parents)
        if not os.path.exists(self.args.save_dir):
            os.makedirs(self.args.save_dir)

        # Read the dictionary of encoded queries.
        # NOTE(review): eval() executes whatever the file contains; only
        # point these paths at trusted files.
        with open(queryEncodedDictPath, 'r') as f:
            self.queryEncodedDict = eval(f.read())

        # Read all table names and map each one to its index in sorted order.
        with open(shortToLongPath, 'r') as f:
            short_to_long = eval(f.read())
        tables = sorted(short_to_long.keys())
        self.table_to_int = {name: idx for idx, name in enumerate(tables)}

        self.datasetnumber = 10
        self.trainList = []
        self.testList = []

    def _state_tensor(self, state):
        """Build a float32 input tensor on the same device as the network.

        The network may have been moved to CUDA in ``__init__``; a tensor
        left on the CPU would then fail in the forward pass.
        """
        device = next(self.actor_net.parameters()).device
        return torch.tensor(state, dtype=torch.float32, device=device)

    def pretreatment(self, path):
        """Parse ``path`` and write the samples into shuffled folds.

        Every line has the form
        ``queryName,origintime,neotime,qpopttime,label``. The samples are
        shuffled, dealt round-robin into ``datasetnumber`` folds, and each
        fold is pickled to ``./data/data<i>.sql``.
        """
        print("Pretreatment running...")
        # Read everything first; the split happens afterwards.
        dataList = []
        with open(path) as source:
            for line in source:
                fields = line.split(",")
                # Round-trip through utf-8-sig to drop a leading BOM.
                queryName = fields[0].encode('utf-8').decode(
                    'utf-8-sig').strip()
                state = self.queryEncodedDict[queryName]
                origintime = int(float(fields[1].strip()))
                neotime = int(float(fields[2].strip()))
                qpopttime = int(float(fields[3].strip()))
                label = int(fields[4].strip())
                dataList.append(
                    data(queryName, state, origintime, neotime, qpopttime,
                         label))

        random.shuffle(dataList)
        # Sample j goes to fold j % datasetnumber.
        folds = [dataList[i::self.datasetnumber]
                 for i in range(self.datasetnumber)]
        for i, fold in enumerate(folds):
            with open("./data/data" + str(i) + ".sql", 'wb') as fold_file:
                pickle.dump(len(fold), fold_file)
                for value in fold:
                    pickle.dump(value, fold_file)
        print("Pretreament data done.")

    def printdata(self):
        """Print every training sample."""
        for sample in self.trainList:
            print(sample)

    def supervised(self):
        """Train with single-sample SGD against one-hot targets, then save.

        Prints the summed loss every 1000 steps and writes
        ``supervised.pt`` once training has finished.
        """
        self.load_data()
        optim = torch.optim.SGD(self.actor_net.parameters(), lr=0.0005)
        loss_func = torch.nn.MSELoss()
        loss1000 = 0
        for step in range(1, 300001):
            index = random.randint(0, len(self.trainList) - 1)
            sample = self.trainList[index]
            state_tensor = self._state_tensor(sample.state)
            predictionRuntime = self.actor_net(state_tensor)
            # One-hot target on the same device as the prediction.
            label = [0 for _ in range(self.num_output)]
            label[sample.label] = 1
            label_tensor = torch.tensor(label, dtype=torch.float32,
                                        device=state_tensor.device)
            loss = loss_func(predictionRuntime, label_tensor)
            optim.zero_grad()  # clear the gradients
            loss.backward()  # compute the gradients
            optim.step()  # apply the gradients and update the parameters
            loss1000 += loss.item()
            if step % 1000 == 0:
                print('[{}] Epoch: {}, Loss: {:.5f}'.format(
                    datetime.now(), step, loss1000))
                loss1000 = 0
        torch.save(self.actor_net.state_dict(),
                   self.args.save_dir + 'supervised.pt')

    # functions to test the network
    def test_network(self):
        """Load ``supervised.pt`` and print predictions and hit counts.

        Prints one comma-separated row per test sample, followed by the
        number of correct predictions on the test and train sets.
        """
        self.load_data()
        model_path = self.args.save_dir + 'supervised.pt'
        self.actor_net.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
        self.actor_net.eval()

        # test set
        correct = 0
        for sample in self.testList:
            output = self.actor_net(self._state_tensor(sample.state))
            prediction = output.detach().cpu().numpy()
            maxindex = np.argmax(prediction)
            print(sample.queryname, ",", sample.origintime, ",",
                  sample.neotime, ",", sample.qpopttime, ",",
                  sample.label, ",", maxindex)
            if maxindex == sample.label:
                correct += 1
        print("测试集:", correct, "\t", len(self.testList))

        # train set
        correct1 = 0
        for sample in self.trainList:
            output = self.actor_net(self._state_tensor(sample.state))
            prediction = output.detach().cpu().numpy()
            if np.argmax(prediction) == sample.label:
                correct1 += 1
        print("训练集", correct1, "\t", len(self.trainList))

    def load_data(self, testnum=0):
        """Load fold ``testnum`` as the test set and all other folds for training.

        Does nothing when training data has already been loaded.

        Args:
            testnum: Index of the fold held out for testing.
        """
        if self.trainList:
            return
        # NOTE(review): pickle.load runs arbitrary code from the file; these
        # files are expected to be the ones this class wrote itself.
        with open("./data/data" + str(testnum) + ".sql", 'rb') as file_test:
            for _ in range(pickle.load(file_test)):
                self.testList.append(pickle.load(file_test))
        for i in range(self.datasetnumber):
            if i == testnum:
                continue
            with open("./data/data" + str(i) + ".sql", 'rb') as file_train:
                for _ in range(pickle.load(file_train)):
                    self.trainList.append(pickle.load(file_train))
        print("load data\ttrainSet:", len(self.trainList), " testSet:",
              len(self.testList))

    def init_weights(self, m):
        """Xavier-initialise linear layers and set their biases to 0.01.

        Meant to be passed to ``Module.apply``; modules that are not
        ``torch.nn.Linear`` are left untouched, as are missing biases.
        """
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)