Example #1
def __init__(self, mode, device, gnet, opt, args, global_ep, global_ep_r,
             res_queue, pid):
    super(Worker, self).__init__()
    self.mode = mode
    self.device = device
    self.id = pid
    self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
    self.gnet, self.opt = gnet, opt
    self.env = LabelEnv(args, self.mode)
    # will be replaced by a model-specific local network
    self.lnet = self.gnet
    self.target_net = self.lnet
    # replay memory
    self.random = random.Random(self.id + args.seed_batch)
    self.buffer = deque()
    self.time_step = 0
    # episode
    self.max_ep = args.episode_train if self.mode == 'offline' else args.episode_test
Example #2
def __init__(self, mode, device, gnets, opts, args, global_ep, global_ep_r,
             res_queue, pid):
    super(WorkerHorizon, self).__init__()
    self.mode = mode
    self.device = device
    self.id = pid
    self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
    self.gnets, self.opts = gnets, opts
    self.env = LabelEnv(args, self.mode)
    self.lnets = []
    self.buffers = []
    for i in range(args.budget):
        agent = TrellisCNN(self.env, args).to(self.device)
        # sync with the global net, then store the local agent and its buffer
        agent.load_state_dict(self.gnets[i].state_dict())
        self.lnets.append(agent)
        self.buffers.append(deque())
    self.random = random.Random(self.id + args.seed_batch)
    # episode
    self.max_ep = args.episode_train if self.mode == 'offline' else args.episode_test
Example #3
File: run.py Project: MyTHWN/active-RL
TRAIN_N = 600
REWEIGHT = args.reward
GREEDY = args.greedy
BUDGET = args.budget
max_len = dataloader.get_max_len()
embedding_size = dataloader.get_embed_size()
parameter_shape = crf.get_para_shape()
print ("max_len is: {}".format(max_len))
print ("crf para size: {}".format(parameter_shape))

# ======================================== active learning =====================================================
qvalue_list = []
action_mark_list = []
prob_list = []
for seed in samples:
    env = LabelEnv(dataloader, crf, seed, VALID_N, TEST_N, TRAIN_N, REWEIGHT, BUDGET)
    agent = AgentParamRNN(GREEDY, max_len, embedding_size, parameter_shape)
    
    print (">>>> Start play")
    step = 0
    while env.cost < BUDGET:
        env.resume()
        observation = env.get_state()
        observ = [observation[0], observation[1], observation[3], observation[4], observation[5]] 
        greedy_flg, action, q_value = agent.get_action(observ)
        reward, observation2, terminal = env.feedback(action)
        (acc_test, acc_valid) = env.eval_tagger()
        print ("cost {}: queried {} with greedy {}:{}, acc=({}, {})".format(env.cost, action, greedy_flg, 
                                                                            q_value.item(), acc_test, acc_valid))
        qvalue_list.append(q_value.item())
        action_mark_list.append(greedy_flg)
Example #4
class Worker(mp.Process):
    def __init__(self, mode, device, gnet, opt, args, global_ep, global_ep_r,
                 res_queue, pid):
        super(Worker, self).__init__()
        self.mode = mode
        self.device = device
        self.id = pid
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnet, self.opt = gnet, opt
        self.env = LabelEnv(args, self.mode)
        # will be replaced by a model-specific local network
        self.lnet = self.gnet
        self.target_net = self.lnet
        # replay memory
        self.random = random.Random(self.id + args.seed_batch)
        self.buffer = deque()
        self.time_step = 0
        # episode
        self.max_ep = args.episode_train if self.mode == 'offline' else args.episode_test

    def run(self):
        total_step = 1
        ep = 1
        while self.g_ep.value < self.max_ep:
            state = self.env.start(self.id + ep)
            ep_r = 0
            res_cost = []
            res_explore = []
            res_qvalue = []
            res_reward = []
            res_acc_test = []
            res_acc_valid = []
            while True:
                # play one step
                explore_flag, action, qvalue = self.lnet.get_action(
                    state, self.device, self.mode)
                reward, state2, done = self.env.feedback(action)
                self.push_to_buffer(state, action, reward, state2, done)
                state = state2
                # record results
                ep_r += reward
                (acc_test, acc_valid) = self.env.eval_tagger()
                res_cost.append(len(self.env.queried))
                res_explore.append(explore_flag)
                res_qvalue.append(qvalue)
                res_reward.append(ep_r)
                res_acc_test.append(acc_test)
                res_acc_valid.append(acc_valid)
                # sync
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    self.update()
                    print("--{} {}: ep={}, left={}".format(
                        self.device, self.id, self.g_ep.value, state[-1]))
                if done:
                    self.record(res_cost, res_explore, res_qvalue, res_reward,
                                res_acc_test, res_acc_valid)
                    print('cost: {}'.format(res_cost))
                    print('explore: {}'.format(res_explore))
                    print('qvalue: {}'.format(res_qvalue))
                    print('reward: {}'.format(res_reward))
                    print('acc_test: {}'.format(res_acc_test))
                    print('acc_valid: {}'.format(res_acc_valid))
                    ep += 1
                    break
                total_step += 1
        self.res_queue.put(None)

    # push new experience to the buffer
    def push_to_buffer(self, state, action, reward, state2, done):
        self.buffer.append((state, action, reward, state2, done))
        if len(self.buffer) > REPLAY_BUFFER_SIZE:
            self.buffer.popleft()

    # construct training batch (of y and qvalue)
    def sample_from_buffer(self, batch_size):
        # experience = (state, action, reward, state2, done)
        # state = (seq_embeddings, seq_confidences, seq_trellis, tagger_para, queried, scope, rest_budget)
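        # NOTE: this generic Worker returns constant placeholder batches here;
        # the model-specific workers (e.g. WorkerHorizon in Example #6) build
        # real batches from their replay buffers.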
        q_batch = torch.ones([1, batch_size],
                             dtype=torch.float64).to(self.device)
        y_batch = torch.ones([1, batch_size],
                             dtype=torch.float64).to(self.device)
        return q_batch, y_batch

    def update(self):
        self.lnet.train()
        q_batch, y_batch = self.sample_from_buffer(BATCH_SIZE)
        loss = F.mse_loss(q_batch, y_batch)
        # set gnet grad = 0
        self.opt.zero_grad()
        # compute lnet gradients
        loss.backward()
        for param in self.lnet.parameters():
            param.grad.data.clamp_(-1, 1)
        # torch.nn.utils.clip_grad_norm_(self.l_agent.parameters(), MAX_GD_NORM)
        # update gnet's grad with lnet's grad
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            if gp.grad is not None:  # skip if gnet's grad was not cleared (no zero_grad())
                return
            gp._grad = lp.grad
        # update gnet one step forward
        self.opt.step()
        # pull gnet's parameters to local
        if self.mode == 'offline':
            self.lnet.load_state_dict(self.gnet.state_dict())
        # update target_net
        if self.time_step % UPDATE_TARGET_ITER == 0:
            self.target_net = copy.deepcopy(self.lnet)
        self.time_step += 1

    def record(self, res_cost, res_explore, res_qvalue, res_reward,
               res_acc_test, res_acc_valid):
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        res = (self.g_ep.value, res_cost, res_explore, res_qvalue, res_reward,
               res_acc_test, res_acc_valid)
        self.res_queue.put(res)
        # monitor
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = res_reward[-1]
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.9 + res_reward[-1] * 0.1
        print("*** {} {} complete ep {} | ep_r={}".format(
            self.device, self.pid, self.g_ep.value, self.g_ep_r.value))
Example #5
def main():
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda:{}".format(args.cuda))
    else:
        device = torch.device("cpu")

    # important! Without this, lnet in the worker cannot run its forward pass
    # spawn works on all platforms; fork is Unix-only and does not support CUDA
    mp.set_start_method('spawn')

    # === multiprocessing ====
    # global agents with different budgets
    agents = []
    opts = []
    for i in range(args.budget):
        agent = TrellisCNN(LabelEnv(args, None), args).to(device)
        # share the global parameters
        agent.share_memory()
        # optimizer for global model
        opt = SharedAdam(agent.parameters(), lr=0.001)
        # store all agents
        agents.append(agent)
        opts.append(opt)
        if i == 0:
            para_size = sum(p.numel() for p in agent.parameters()
                            if p.requires_grad)
            print('global parameter size={}*{},'.format(
                para_size, args.budget))

    # offline train agent for args.episode_train rounds
    start_time = time.time()
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
    tr_workers = [
        WorkerHorizon('offline', device, agents, opts, args, global_ep,
                      global_ep_r, res_queue, pid) for pid in range(5)
    ]
    [w.start() for w in tr_workers]
    tr_result = []
    while True:
        res = res_queue.get()
        if res is not None:
            tr_result.append(res)
        else:
            break
    [w.join() for w in tr_workers]
    print("Training Done! Cost {} for {} episodes.".format(
        time.time() - start_time, args.episode_train))

    # online test agent for args.episode_test rounds
    start_time = time.time()
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
    ts_workers = [
        WorkerHorizon('online', device, agents, opts, args, global_ep,
                      global_ep_r, res_queue, pid) for pid in range(5)
    ]
    [w.start() for w in ts_workers]
    ts_result = []
    while True:
        res = res_queue.get()
        if res is not None:
            ts_result.append(res)
        else:
            break
    [w.join() for w in ts_workers]
    print("Testing Done! Cost {} for {} episodes.".format(
        time.time() - start_time, args.episode_test))

    num = "num" if args.num_flag else ""
    emb = "embed" if args.embed_flag else ""
    filename = "./results_mp/" + args.data + num + emb + "_" + args.model + "_horizon_" \
                + str(args.budget) + "bgt_" + str(args.init) + "init_" \
                + str(args.episode_train) + "trainEp_" + str(args.episode_test) + "testEp"

    with open(filename + ".mp", "wb") as result:
        # format:
        # tr_result = [res]
        # res = (g_ep, cost, explore, qvalue, r, acc_test, acc_valid)
        pickle.dump((tr_result, ts_result), result)
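
Note: SharedAdam above is not a stock torch.optim class; it is the shared-state Adam variant commonly paired with torch.multiprocessing in A3C-style code. A minimal sketch of that pattern is below (the project's own implementation may differ, and the exact state layout depends on the PyTorch version):

import torch

class SharedAdam(torch.optim.Adam):
    """Adam whose per-parameter state lives in shared memory, so every worker
    process updates the same moment estimates of the global network."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas,
                                         eps=eps, weight_decay=weight_decay)
        # pre-build the optimizer state and move its tensors to shared memory
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # note: newer PyTorch versions expect state['step'] to be a
                # singleton tensor rather than an int
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
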
Example #6
class WorkerHorizon(mp.Process):
    def __init__(self, mode, device, gnets, opts, args, global_ep, global_ep_r,
                 res_queue, pid):
        super(WorkerHorizon, self).__init__()
        self.mode = mode
        self.device = device
        self.id = pid
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnets, self.opts = gnets, opts
        self.env = LabelEnv(args, self.mode)
        self.lnets = []
        self.buffers = []
        for i in range(args.budget):
            agent = TrellisCNN(self.env, args).to(self.device)
            # sync with the global net, then store the local agent and its buffer
            agent.load_state_dict(self.gnets[i].state_dict())
            self.lnets.append(agent)
            self.buffers.append(deque())
        self.random = random.Random(self.id + args.seed_batch)
        # episode
        self.max_ep = args.episode_train if self.mode == 'offline' else args.episode_test

    def run(self):
        total_step = 1
        ep = 1
        while self.g_ep.value < self.max_ep:
            state = self.env.start(self.id + ep)
            ep_r = 0
            res_cost = []
            res_explore = []
            res_qvalue = []
            res_reward = []
            res_acc_test = []
            res_acc_valid = []
            while True:
                # play one step
                horizon = self.env.get_horizon()
                explore_flag, action, qvalue = (
                    self.lnets[horizon - 1].get_action(state, self.device))
                reward, state2, done = self.env.feedback(action)
                self.push_to_buffer(state, action, reward, state2, done,
                                    horizon)
                state = state2
                # record results
                ep_r += reward
                (acc_test, acc_valid) = self.env.eval_tagger()
                res_cost.append(len(self.env.queried))
                res_explore.append(explore_flag)
                res_qvalue.append(qvalue)
                res_reward.append(ep_r)
                res_acc_test.append(acc_test)
                res_acc_valid.append(acc_valid)
                # sync
                self.update(horizon)
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    print("--{} {}: ep={}, left={}".format(
                        self.device, self.id, self.g_ep.value, state[-1]))
                if done:
                    self.record(res_cost, res_explore, res_qvalue, res_reward,
                                res_acc_test, res_acc_valid)
                    print('cost: {}'.format(res_cost))
                    print('explore: {}'.format(res_explore))
                    print('qvalue: {}'.format(res_qvalue))
                    print('reward: {}'.format(res_reward))
                    print('acc_test: {}'.format(res_acc_test))
                    print('acc_valid: {}'.format(res_acc_valid))
                    ep += 1
                    break
                total_step += 1
        self.res_queue.put(None)

    # push new experience to the buffer
    def push_to_buffer(self, state, action, reward, state2, done, horizon):
        self.buffers[horizon - 1].append((state, action, reward, state2, done))
        if len(self.buffers[horizon - 1]) > REPLAY_BUFFER_SIZE:
            self.buffers[horizon - 1].popleft()

    # construct training batch (of y and qvalue)
    def sample_from_buffer(self, batch_size, horizon):
        # experience = (state, action, reward, state2, done)
        # state = (seq_embeddings, seq_confidences, seq_trellis, tagger_para, queried, scope, rest_budget)
        minibatch = self.random.sample(
            self.buffers[horizon - 1],
            min(len(self.buffers[horizon - 1]), batch_size))
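        # each e = (state, action, reward, next_state, done); state fields are
        # indexed per the comment above: [0]=seq_embeddings, [1]=seq_confidences,
        # [2]=seq_trellis, [4]=queried, [5]=scope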
        t_batch = torch.from_numpy(
            np.array([e[0][2][e[1]] for e in minibatch])
        ).type(torch.FloatTensor).unsqueeze(1).to(self.device)
        a_batch = torch.from_numpy(
            np.array([e[0][0][e[1]] for e in minibatch])
        ).type(torch.FloatTensor).to(self.device)
        c_batch = torch.from_numpy(
            np.array([[e[0][1][e[1]]] for e in minibatch])
        ).type(torch.FloatTensor).to(self.device)
        # compute Q(s_t, a)
        q_batch = self.lnets[horizon - 1](t_batch, a_batch, c_batch)
        # compute max Q'(s_t+1, a)
        r_batch = [e[2] for e in minibatch]
        y_batch = []
        for i, e in enumerate(minibatch):
            if e[4]:
                y_batch.append(r_batch[i])
            else:
                candidates = [
                    k for k, idx in enumerate(e[3][5]) if idx not in e[3][4]
                ]
                q_values = []
                for k in candidates:
                    t = torch.from_numpy(e[3][2][k]).type(
                        torch.FloatTensor).unsqueeze(0).unsqueeze(0).to(
                            self.device)
                    a = torch.from_numpy(e[3][0][k]).type(
                        torch.FloatTensor).unsqueeze(0).to(self.device)
                    c = torch.from_numpy(np.array(e[3][1][k])).type(
                        torch.FloatTensor).unsqueeze(0).to(self.device)
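                    # candidates at the next state are scored with the net for
                    # the next (one-smaller) horizon, i.e. index horizon - 2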
                    q = self.lnets[horizon - 2](t, a, c).detach().item()
                    q_values.append(q)
                y_batch.append(max(q_values) * GAMMA + r_batch[i])
        y_batch = torch.from_numpy(np.array(y_batch)).type(
            torch.FloatTensor).to(self.device)
        return q_batch, y_batch

    def update(self, horizon):
        self.lnets[horizon - 1].train()
        q_batch, y_batch = self.sample_from_buffer(BATCH_SIZE, horizon)
        loss = F.mse_loss(q_batch, y_batch)
        # set gnet grad = 0
        self.opts[horizon - 1].zero_grad()
        # compute lnet gradients
        loss.backward()
        for param in self.lnets[horizon - 1].parameters():
            param.grad.data.clamp_(-1, 1)
        # torch.nn.utils.clip_grad_norm_(self.l_agent.parameters(), MAX_GD_NORM)
        # update gnet's grad with lnet's grad
        for lp, gp in zip(self.lnets[horizon - 1].parameters(),
                          self.gnets[horizon - 1].parameters()):
            if gp.grad is not None:  # skip if gnet's grad was not cleared (no zero_grad())
                return
            gp._grad = lp.grad
        # update gnet one step forward
        self.opts[horizon - 1].step()
        # pull gnet's parameters to local
        if self.mode == 'offline':
            self.lnets[horizon - 1].load_state_dict(
                self.gnets[horizon - 1].state_dict())

    def record(self, res_cost, res_explore, res_qvalue, res_reward,
               res_acc_test, res_acc_valid):
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        res = (self.g_ep.value, res_cost, res_explore, res_qvalue, res_reward,
               res_acc_test, res_acc_valid)
        self.res_queue.put(res)
        # monitor
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.:
                self.g_ep_r.value = res_reward[-1]
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.9 + res_reward[-1] * 0.1
        print("*** {} {} complete ep {} | ep_r={}".format(
            self.device, self.pid, self.g_ep.value, self.g_ep_r.value))
Example #7
def main():
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda:{}".format(args.cuda))
    else:
        device = torch.device("cpu")

    # important! Without this, lnet in the worker cannot run its forward pass
    # spawn works on all platforms; fork is Unix-only and does not support CUDA
    mp.set_start_method('spawn')

    # === multiprocessing ====
    # global agent
    if args.model == 'ParamRNN':
        agent = ParamRNN(LabelEnv(args, None), args).to(device)
    elif args.model == 'ParamRNNBudget':
        agent = ParamRNNBudget(LabelEnv(args, None), args).to(device)
    elif args.model == 'TrellisCNN' or args.model == 'TrellisSupervised':
        agent = TrellisCNN(LabelEnv(args, None), args).to(device)
    elif args.model == 'TrellisBudget':
        agent = TrellisBudget(LabelEnv(args, None), args).to(device)
    elif args.model == 'PAL':
        agent = PAL(LabelEnv(args, None), args).to(device)
    elif args.model == 'SepRNN':
        agent = SepRNN(LabelEnv(args, None), args).to(device)
    elif args.model == 'Rand' or args.model == 'TE':
        agent = None
    else:
        print("agent model {} not implemented!!".format(args.model))
        return
    # optimizer for global model
    opt = SharedAdam(agent.parameters(), lr=0.001) if agent else None
    # share the global parameters
    if agent:
        agent.share_memory()
        para_size = sum(p.numel() for p in agent.parameters()
                        if p.requires_grad)
        print('global parameter size={}'.format(para_size))

    # offline train agent for args.episode_train rounds
    start_time = time.time()
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
    if args.model == 'ParamRNN':
        tr_workers = [
            WorkerParam('offline', device, agent, opt, args, global_ep,
                        global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'ParamRNNBudget':
        tr_workers = [
            WorkerBudget('offline', device, agent, opt, args, global_ep,
                         global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'TrellisCNN' or args.model == 'PAL':
        tr_workers = [
            WorkerTrellis('offline', device, agent, opt, args, global_ep,
                          global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'TrellisSupervised':
        tr_workers = [
            WorkerSupervised('offline', device, agent, opt, args, global_ep,
                             global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'TrellisBudget':
        tr_workers = [
            WorkerTrellisBudget('offline', device, agent, opt, args, global_ep,
                                global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'SepRNN':
        tr_workers = [
            WorkerSep('offline', device, agent, opt, args, global_ep,
                      global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    tr_result = []
    if agent:
        [w.start() for w in tr_workers]
        while True:
            res = res_queue.get()
            if res is not None:
                tr_result.append(res)
            else:
                break
        [w.join() for w in tr_workers]
        print("Training Done! Cost {} for {} episodes.".format(
            time.time() - start_time, args.episode_train))

    # online test agent for args.episode_test rounds
    start_time = time.time()
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()
    if args.model == 'ParamRNN':
        ts_workers = [
            WorkerParam('online', device, agent, opt, args, global_ep,
                        global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'ParamRNNBudget':
        ts_workers = [
            WorkerBudget('online', device, agent, opt, args, global_ep,
                         global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'TrellisCNN' or args.model == 'PAL':
        ts_workers = [
            WorkerTrellis('online', device, agent, opt, args, global_ep,
                          global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'TrellisSupervised':
        ts_workers = [
            WorkerSupervised('online', device, agent, opt, args, global_ep,
                             global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'TrellisBudget':
        ts_workers = [
            WorkerTrellisBudget('online', device, agent, opt, args, global_ep,
                                global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'SepRNN':
        ts_workers = [
            WorkerSep('online', device, agent, opt, args, global_ep,
                      global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]
    elif args.model == 'Rand' or args.model == 'TE':
        ts_workers = [
            WorkerHeur('online', device, agent, opt, args, global_ep,
                       global_ep_r, res_queue, pid)
            for pid in range(args.worker_n)
        ]

    ts_result = []
    [w.start() for w in ts_workers]
    while True:
        res = res_queue.get()
        if res is not None:
            ts_result.append(res)
        else:
            break
    [w.join() for w in ts_workers]
    print("Testing Done! Cost {} for {} episodes.".format(
        time.time() - start_time, args.episode_test))

    num = "num" if args.num_flag else ""
    emb = "embed" if args.embed_flag else ""
    filename = "./results_mp/" + args.data + num + emb + "_" + args.model + "_" \
                + str(args.budget) + "bgt_" + str(args.init) + "init_" \
                + str(args.episode_train) + "trainEp_" + str(args.episode_test) + "testEp"

    with open(filename + ".mp", "wb") as result:
        # format:
        # tr_result = [res]
        # res = (g_ep, cost, explore, qvalue, r, acc_test, acc_valid)
        pickle.dump((tr_result, ts_result), result)
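
For completeness, a minimal sketch of loading the pickled results back for analysis, assuming the 7-field res layout produced by record() above; the path below is hypothetical, so substitute the filename constructed by the script:

import pickle

results_path = "./results_mp/example_TrellisCNN_10bgt_5init_100trainEp_10testEp.mp"  # hypothetical

with open(results_path, "rb") as f:
    tr_result, ts_result = pickle.load(f)

# each res = (g_ep, cost, explore, qvalue, reward, acc_test, acc_valid)
for g_ep, cost, explore, qvalue, reward, acc_test, acc_valid in ts_result:
    print("ep {}: final reward={:.4f}, final test acc={:.4f}".format(
        g_ep, reward[-1], acc_test[-1]))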