class Worker(mp.Process):
    def __init__(self, mode, device, gnet, opt, args, global_ep, global_ep_r,
                 res_queue, pid):
        super(Worker, self).__init__()
        self.mode = mode
        self.device = device
        self.id = pid
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnet, self.opt = gnet, opt
        self.env = LabelEnv(args, self.mode)
        # will be changed to specific model
        self.lnet = self.gnet
        self.target_net = self.lnet
        # replay memory
        self.random = random.Random(self.id + args.seed_batch)
        self.buffer = deque()
        self.time_step = 0
        # episode
        self.max_ep = args.episode_train if self.mode == 'offline' else args.episode_test

    def run(self):
        total_step = 1
        ep = 1
        while self.g_ep.value < self.max_ep:
            state = self.env.start(self.id + ep)
            ep_r = 0
            res_cost = []
            res_explore = []
            res_qvalue = []
            res_reward = []
            res_acc_test = []
            res_acc_valid = []
            while True:
                # play one step
                explore_flag, action, qvalue = self.lnet.get_action(
                    state, self.device, self.mode)
                reward, state2, done = self.env.feedback(action)
                self.push_to_buffer(state, action, reward, state2, done)
                state = state2
                # record results
                ep_r += reward
                (acc_test, acc_valid) = self.env.eval_tagger()
                res_cost.append(len(self.env.queried))
                res_explore.append(explore_flag)
                res_qvalue.append(qvalue)
                res_reward.append(ep_r)
                res_acc_test.append(acc_test)
                res_acc_valid.append(acc_valid)
                # sync
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    self.update()
                    print("--{} {}: ep={}, left={}".format(
                        self.device, self.id, self.g_ep.value, state[-1]))
                if done:
                    self.record(res_cost, res_explore, res_qvalue, res_reward,
                                res_acc_test, res_acc_valid)
                    print('cost: {}'.format(res_cost))
                    print('explore: {}'.format(res_explore))
                    print('qvalue: {}'.format(res_qvalue))
                    print('reward: {}'.format(res_reward))
                    print('acc_test: {}'.format(res_acc_test))
                    print('acc_valid: {}'.format(res_acc_valid))
                    ep += 1
                    break
                total_step += 1
        self.res_queue.put(None)

    # push new experience to the buffer
    def push_to_buffer(self, state, action, reward, state2, done):
        self.buffer.append((state, action, reward, state2, done))
        if len(self.buffer) > REPLAY_BUFFER_SIZE:
            self.buffer.popleft()

    # construct training batch (of y and qvalue)
    def sample_from_buffer(self, batch_size):
        # experience = (state, action, reward, state2, done)
        # state = (seq_embeddings, seq_confidences, seq_trellis, tagger_para, queried, scope, rest_budget)
        # NOTE: placeholder batch of dummy tensors; see
        # WorkerHorizon.sample_from_buffer below for a fully worked-out version.
        q_batch = torch.ones([1, batch_size], dtype=torch.float64).to(self.device)
        y_batch = torch.ones([1, batch_size], dtype=torch.float64).to(self.device)
        return q_batch, y_batch

    def update(self):
        self.lnet.train()
        q_batch, y_batch = self.sample_from_buffer(BATCH_SIZE)
        loss = F.mse_loss(q_batch, y_batch)
        # set gnet grad = 0
        self.opt.zero_grad()
        # compute lnet gradients
        loss.backward()
        for param in self.lnet.parameters():
            param.grad.data.clamp_(-1, 1)
        # torch.nn.utils.clip_grad_norm_(self.l_agent.parameters(), MAX_GD_NORM)
        # update gnet's grad with lnet's grad
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            if gp.grad is not None:  # if is not cleared (not zero_grad())
                return
            gp._grad = lp.grad
        # update gnet one step forward
        self.opt.step()
        # pull gnet's parameters to local
        if self.mode == 'offline':
            self.lnet.load_state_dict(self.gnet.state_dict())
        # update target_net
        if self.time_step % UPDATE_TARGET_ITER == 0:
            self.target_net = copy.deepcopy(self.lnet)
        self.time_step += 1

    def record(self, res_cost, res_explore, res_qvalue, res_reward,
               res_acc_test, res_acc_valid):
        with self.g_ep.get_lock():
            self.g_ep.value += 1
            res = (self.g_ep.value, res_cost, res_explore, res_qvalue,
                   res_reward, res_acc_test, res_acc_valid)
            self.res_queue.put(res)
        # monitor
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = res_reward[-1]
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.9 + res_reward[-1] * 0.1
        print("*** {} {} complete ep {} | ep_r={}".format(
            self.device, self.pid, self.g_ep.value, self.g_ep_r.value))
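

# The sketch below is an illustrative, hypothetical launcher; it is not part of
# the original code. It shows one plausible way to create the shared episode
# counter, running-reward value, and result queue that Worker expects, start a
# pool of workers, and drain the queue until every worker has pushed its final
# None. The function name, the 'cpu' device string, and the worker count are
# assumptions for illustration only; `gnet`, `opt`, and `args` stand for
# whatever the surrounding project builds as the global network, its shared
# optimizer, and the experiment configuration. Results are drained before
# join() because a multiprocessing Queue's feeder thread can block join() while
# unread items remain.
def _example_launch_workers(args, gnet, opt, n_workers=4):
    gnet.share_memory()                      # expose global parameters to all processes
    global_ep = mp.Value('i', 0)             # shared episode counter
    global_ep_r = mp.Value('d', 0.)          # shared running episode reward
    res_queue = mp.Queue()                   # per-episode results from workers
    workers = [Worker('offline', 'cpu', gnet, opt, args,
                      global_ep, global_ep_r, res_queue, pid=i)
               for i in range(n_workers)]
    for w in workers:
        w.start()
    results, finished = [], 0
    while finished < n_workers:              # each worker puts None when run() exits
        r = res_queue.get()
        if r is None:
            finished += 1
        else:
            results.append(r)
    for w in workers:
        w.join()
    return results
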
# ======================================== active learning =====================================================
qvalue_list = []
action_mark_list = []
prob_list = []
for seed in samples:
    env = LabelEnv(dataloader, crf, seed, VALID_N, TEST_N, TRAIN_N, REWEIGHT, BUDGET)
    agent = AgentParamRNN(GEEDY, max_len, embedding_size, parameter_shape)
    print(">>>> Start play")
    step = 0
    while env.cost < BUDGET:
        env.resume()
        observation = env.get_state()
        observ = [observation[0], observation[1], observation[3],
                  observation[4], observation[5]]
        greedy_flg, action, q_value = agent.get_action(observ)
        reward, observation2, terminal = env.feedback(action)
        (acc_test, acc_valid) = env.eval_tagger()
        print("cost {}: queried {} with greedy {}:{}, acc=({}, {})".format(
            env.cost, action, greedy_flg, q_value.item(), acc_test, acc_valid))
        qvalue_list.append(q_value.item())
        action_mark_list.append(greedy_flg)
        prob_list.append(observation[1][action])
        if env.cost % 10 == 0:
            step += 10
        if env.cost < 3:
            continue
        for n in range(20 + step):
            env.reboot()
            while not env.terminal:
                observation = env.get_state()
class WorkerHorizon(mp.Process):
    def __init__(self, mode, device, gnets, opts, args, global_ep, global_ep_r,
                 res_queue, pid):
        super(WorkerHorizon, self).__init__()
        self.mode = mode
        self.device = device
        self.id = pid
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnets, self.opts = gnets, opts
        self.env = LabelEnv(args, self.mode)
        self.lnets = []
        self.buffers = []
        # store all agents (one local net and replay buffer per budget step)
        for i in range(args.budget):
            agent = TrellisCNN(self.env, args).to(self.device)
            agent.load_state_dict(self.gnets[i].state_dict())
            self.lnets.append(agent)
            self.buffers.append(deque())
        self.random = random.Random(self.id + args.seed_batch)
        # episode
        self.max_ep = args.episode_train if self.mode == 'offline' else args.episode_test

    def run(self):
        total_step = 1
        ep = 1
        while self.g_ep.value < self.max_ep:
            state = self.env.start(self.id + ep)
            ep_r = 0
            res_cost = []
            res_explore = []
            res_qvalue = []
            res_reward = []
            res_acc_test = []
            res_acc_valid = []
            while True:
                # play one step
                horizon = self.env.get_horizon()
                explore_flag, action, qvalue = self.lnets[horizon - 1].get_action(
                    state, self.device)
                reward, state2, done = self.env.feedback(action)
                self.push_to_buffer(state, action, reward, state2, done, horizon)
                state = state2
                # record results
                ep_r += reward
                (acc_test, acc_valid) = self.env.eval_tagger()
                res_cost.append(len(self.env.queried))
                res_explore.append(explore_flag)
                res_qvalue.append(qvalue)
                res_reward.append(ep_r)
                res_acc_test.append(acc_test)
                res_acc_valid.append(acc_valid)
                # sync
                self.update(horizon)
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    print("--{} {}: ep={}, left={}".format(
                        self.device, self.id, self.g_ep.value, state[-1]))
                if done:
                    self.record(res_cost, res_explore, res_qvalue, res_reward,
                                res_acc_test, res_acc_valid)
                    print('cost: {}'.format(res_cost))
                    print('explore: {}'.format(res_explore))
                    print('qvalue: {}'.format(res_qvalue))
                    print('reward: {}'.format(res_reward))
                    print('acc_test: {}'.format(res_acc_test))
                    print('acc_valid: {}'.format(res_acc_valid))
                    ep += 1
                    break
                total_step += 1
        self.res_queue.put(None)

    # push new experience to the buffer
    def push_to_buffer(self, state, action, reward, state2, done, horizon):
        self.buffers[horizon - 1].append((state, action, reward, state2, done))
        if len(self.buffers[horizon - 1]) > REPLAY_BUFFER_SIZE:
            self.buffers[horizon - 1].popleft()

    # construct training batch (of y and qvalue)
    def sample_from_buffer(self, batch_size, horizon):
        # experience = (state, action, reward, state2, done)
        # state = (seq_embeddings, seq_confidences, seq_trellis, tagger_para, queried, scope, rest_budget)
        minibatch = self.random.sample(
            self.buffers[horizon - 1],
            min(len(self.buffers[horizon - 1]), batch_size))
        t_batch = torch.from_numpy(np.array([
            e[0][2][e[1]] for e in minibatch
        ])).type(torch.FloatTensor).unsqueeze(1).to(self.device)
        a_batch = torch.from_numpy(np.array(
            [e[0][0][e[1]] for e in minibatch])).type(torch.FloatTensor).to(self.device)
        c_batch = torch.from_numpy(
            np.array([[e[0][1][e[1]]] for e in minibatch
                      ])).type(torch.FloatTensor).to(self.device)
        # compute Q(s_t, a)
        q_batch = self.lnets[horizon - 1](t_batch, a_batch, c_batch)
        # compute max Q'(s_t+1, a)
        r_batch = [e[2] for e in minibatch]
        y_batch = []
        for i, e in enumerate(minibatch):
            if e[4]:
                y_batch.append(r_batch[i])
            else:
                candidates = [
                    k for k, idx in enumerate(e[3][5]) if idx not in e[3][4]
                ]
                q_values = []
                for k in candidates:
                    t = torch.from_numpy(e[3][2][k]).type(
                        torch.FloatTensor).unsqueeze(0).unsqueeze(0).to(
                            self.device)
                    a = torch.from_numpy(e[3][0][k]).type(
                        torch.FloatTensor).unsqueeze(0).to(self.device)
                    c = torch.from_numpy(np.array(e[3][1][k])).type(
                        torch.FloatTensor).unsqueeze(0).to(self.device)
                    # the next state has one less budget step left, so evaluate
                    # it with the agent for the next (smaller) horizon
                    q = self.lnets[horizon - 2](t, a, c).detach().item()
                    q_values.append(q)
                y_batch.append(max(q_values) * GAMMA + r_batch[i])
        y_batch = torch.from_numpy(np.array(y_batch)).type(
            torch.FloatTensor).to(self.device)
        return q_batch, y_batch

    def update(self, horizon):
        self.lnets[horizon - 1].train()
        q_batch, y_batch = self.sample_from_buffer(BATCH_SIZE, horizon)
        loss = F.mse_loss(q_batch, y_batch)
        # set gnet grad = 0
        self.opts[horizon - 1].zero_grad()
        # compute lnet gradients
        loss.backward()
        for param in self.lnets[horizon - 1].parameters():
            param.grad.data.clamp_(-1, 1)
        # torch.nn.utils.clip_grad_norm_(self.l_agent.parameters(), MAX_GD_NORM)
        # update gnet's grad with lnet's grad
        for lp, gp in zip(self.lnets[horizon - 1].parameters(),
                          self.gnets[horizon - 1].parameters()):
            if gp.grad is not None:  # if is not cleared (not zero_grad())
                return
            gp._grad = lp.grad
        # update gnet one step forward
        self.opts[horizon - 1].step()
        # pull gnet's parameters to local
        if self.mode == 'offline':
            self.lnets[horizon - 1].load_state_dict(
                self.gnets[horizon - 1].state_dict())

    def record(self, res_cost, res_explore, res_qvalue, res_reward,
               res_acc_test, res_acc_valid):
        with self.g_ep.get_lock():
            self.g_ep.value += 1
            res = (self.g_ep.value, res_cost, res_explore, res_qvalue,
                   res_reward, res_acc_test, res_acc_valid)
            self.res_queue.put(res)
        # monitor
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = res_reward[-1]
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.9 + res_reward[-1] * 0.1
        print("*** {} {} complete ep {} | ep_r={}".format(
            self.device, self.pid, self.g_ep.value, self.g_ep_r.value))
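

# The sketch below is an illustrative, hypothetical launcher for the
# horizon-indexed variant; it is not part of the original code. It shows one
# plausible way to build one shared net and one optimizer per unit of budget,
# which is what WorkerHorizon indexes by `horizon - 1`, and then start a pool
# of workers. Constructing the global nets as TrellisCNN(env, args) from a
# fresh LabelEnv, using torch.optim.Adam, the 'cpu' device string, and the
# worker count are all assumptions for illustration only.
def _example_launch_horizon_workers(args, n_workers=4):
    env = LabelEnv(args, 'offline')          # only used here to size the networks
    gnets, opts = [], []
    for _ in range(args.budget):             # one (net, optimizer) pair per horizon
        gnet = TrellisCNN(env, args)
        gnet.share_memory()                  # expose global parameters to all processes
        gnets.append(gnet)
        opts.append(torch.optim.Adam(gnet.parameters()))
    global_ep = mp.Value('i', 0)             # shared episode counter
    global_ep_r = mp.Value('d', 0.)          # shared running episode reward
    res_queue = mp.Queue()                   # per-episode results from workers
    workers = [WorkerHorizon('offline', 'cpu', gnets, opts, args,
                             global_ep, global_ep_r, res_queue, pid=i)
               for i in range(n_workers)]
    for w in workers:
        w.start()
    results, finished = [], 0
    while finished < n_workers:              # each worker puts None when run() exits
        r = res_queue.get()
        if r is None:
            finished += 1
        else:
            results.append(r)
    for w in workers:
        w.join()
    return results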