Example #1
class train_DQN():
    def __init__(self,
                 env_id,
                 seed=0,
                 lr=1e-5,
                 n_step=3,
                 gamma=0.99,
                 n_workers=20,
                 max_norm=40,
                 target_update_interval=2500,
                 save_interval=5000,
                 batch_size=64,
                 buffer_size=1e6,
                 prior_alpha=0.6,
                 prior_beta=0.4,
                 publish_param_interval=32,
                 max_step=1e5):
        self.env = gym.make(env_id)
        self.seed = seed
        self.lr = lr
        self.n_step = n_step
        self.gamma = gamma
        self.max_norm = max_norm
        self.target_update_interval = target_update_interval
        self.save_interval = save_interval
        self.publish_param_interval = publish_param_interval
        self.batch_size = batch_size
        self.prior_beta = prior_beta
        self.max_step = max_step

        self.buffer = CustomPrioritizedReplayBuffer(size=buffer_size,
                                                    alpha=prior_alpha)
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = DuelingDQN(self.env).to(self.device)
        self.tgt_model = DuelingDQN(self.env).to(self.device)
        self.tgt_model.load_state_dict(self.model.state_dict())
        self.optimizer = torch.optim.RMSprop(self.model.parameters(),
                                             self.lr,
                                             alpha=0.95,
                                             eps=1.5e-7,
                                             centered=True)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                         step_size=1000,
                                                         gamma=0.99)
        self.beta_by_frame = lambda frame_idx: min(
            1.0, self.prior_beta + frame_idx * (1.0 - self.prior_beta) / 1000)
        self.batch_recorder = BatchRecorder(env_id=env_id,
                                            env_seed=seed,
                                            n_workers=n_workers,
                                            buffer=self.buffer,
                                            n_steps=n_step,
                                            gamma=gamma,
                                            max_episode_length=50000)
        self.writer = SummaryWriter(
            comment="-{}-learner".format(self.env.unwrapped.spec.id))
Example #2
async def main():
    """
    main event loop
    """
    args = argparser()
    utils.set_global_seeds(args.seed, use_torch=False)

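    # Relay processes that shuttle data between this replay process and the
    # actors/learner: incoming experience batches, incoming priority updates,
    # and outgoing sampled batches.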
    procs = [
        Process(target=recv_batch_device),
        Process(target=recv_prios_device),
        Process(target=send_batch_device),
    ]
    for p in procs:
        p.start()

    buffer = CustomPrioritizedReplayBuffer(args.replay_buffer_size, args.alpha)
    exe = ThreadPoolExecutor()
    event = asyncio.Event()
    lock = asyncio.Lock()

    # TODO: How to decide the proper number of asyncio workers?
    workers = []
    for _ in range(args.n_recv_batch_worker):
        w = recv_batch_worker(buffer, exe, event, lock, args.threshold_size)
        workers.append(w)
    for _ in range(args.n_recv_prios_worker):
        w = recv_prios_worker(buffer, exe, event, lock)
        workers.append(w)
    for _ in range(args.n_send_batch_worker):
        w = send_batch_worker(buffer, exe, event, lock, args.batch_size, args.beta)
        workers.append(w)

    await asyncio.gather(*workers)
    return True
Example #3
def replay(args):
    comm = global_dict['comm_local']
    prev_t = time.time()
    global push_size, sample_size
    writer = SummaryWriter(log_dir=os.path.join(
        args['log_dir'], f'{global_dict["unit_idx"]}-replay'))
    tb_step = 0

    buffer = CustomPrioritizedReplayBuffer(args['replay_buffer_size'],
                                           args['alpha'])
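    # Pre-allocated MPI receive buffers: three rotating 50 MiB buffers for
    # experience batches from the actors, a 1-byte buffer for the learner's
    # batch request (TAG_SEND_BATCH), and a 10 KiB buffer for updated
    # priorities (TAG_RECV_PRIOS).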
    bufs = [
        [bytearray(50 * 1024 * 1024) for _ in range(3)],
        bytearray(1),
        bytearray(10 * 1024),
    ]
    _n = 0
    requests = [
        comm.Irecv(buf=bufs[0][_n],
                   source=MPI.ANY_SOURCE,
                   tag=utils.TAG_RECV_BATCH),
        comm.Irecv(buf=bufs[1],
                   source=global_dict['rank_learner'],
                   tag=utils.TAG_SEND_BATCH),
        comm.Irecv(buf=bufs[2],
                   source=global_dict['rank_learner'],
                   tag=utils.TAG_RECV_PRIOS),
    ]

    lock = threading.Lock()
    thread_pool_executor = ThreadPoolExecutor(max_workers=10)
    start_sending_batch_condition = threading.Condition()
    pending_count_record = []

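    # Wait for any outstanding receive to complete, hand the filled buffer to
    # a worker thread, then immediately re-post a fresh Irecv for that slot so
    # the next message can arrive while the worker runs.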
    while True:
        index = Request.Waitany(requests)
        thread_pool_executor.submit(worker, args, index,
                                    bufs[index] if index != 0 else bufs[0][_n],
                                    lock, buffer,
                                    start_sending_batch_condition)
        if index == 0:  # batch recv
            _n = (_n + 1) % 3
            requests[0] = comm.Irecv(buf=bufs[0][_n],
                                     source=MPI.ANY_SOURCE,
                                     tag=utils.TAG_RECV_BATCH)
        elif index == 1:  # batch send
            requests[1] = comm.Irecv(buf=bufs[1],
                                     source=global_dict['rank_learner'],
                                     tag=utils.TAG_SEND_BATCH)
        elif index == 2:  # prios recv
            requests[2] = comm.Irecv(buf=bufs[2],
                                     source=global_dict['rank_learner'],
                                     tag=utils.TAG_RECV_PRIOS)
        pending_count_record.append(thread_pool_executor._work_queue.qsize())

        delta_t = time.time() - prev_t
        if delta_t > 60:
            tb_step += 1
            writer.add_scalar('replay/1_push_per_second', push_size / delta_t,
                              tb_step)
            writer.add_scalar('replay/2_sample_per_second',
                              sample_size / delta_t, tb_step)
            writer.add_scalar('replay/3_buffer_size', len(buffer), tb_step)
            writer.add_scalar('replay/4_pending_count',
                              np.mean(pending_count_record), tb_step)
            writer.add_scalar('replay/5_priorities_sum', buffer._it_sum.sum(),
                              tb_step)
            sample_size = 0
            push_size = 0
            prev_t = time.time()
            pending_count_record.clear()
Example #4
                                                data_ids=data_ids,
                                                weights=weights,
                                                idxes=idxes,
                                                timestamp=timestamps)


if __name__ == '__main__':
    """
    environment parameters
    """
    (n_actors, replay_ip, learner_ip, registerActorPort, sendBatchPrioriPort,
     updatePrioriPort, sampleDataPort, cacheUpdatePort) = get_environ()

    args = argparser()
    utils.set_global_seeds(args.seed, use_torch=False)
    buffer = CustomPrioritizedReplayBuffer(args.replay_buffer_size, args.alpha,
                                           n_actors)
    local_buffer_size = args.replay_buffer_size // n_actors
    event = Event()
    lock = Lock()

    #conn = grpc.insecure_channel(learner_ip + ':' + cacheUpdatePort)
    #cache_data_client = apex_data_pb2_grpc.CacheUpdateStub(channel=conn)
    """
    actor send (actor_id, data_id, priori) to replay buffer
    """
    #sendBatchPrioriPort = '8080'
    sendBatchPrioriServer = grpc.server(ThreadPoolExecutor(max_workers=4))
    apex_data_pb2_grpc.add_SendBatchPrioriServicer_to_server(
        SendBatchPriori(), sendBatchPrioriServer)
    sendBatchPrioriServer.add_insecure_port(replay_ip + ':' +
                                            sendBatchPrioriPort)
Example #5
class train_DQN():
    def __init__(self,
                 env_id,
                 seed=0,
                 lr=1e-5,
                 n_step=3,
                 gamma=0.99,
                 n_workers=20,
                 max_norm=40,
                 target_update_interval=2500,
                 save_interval=5000,
                 batch_size=64,
                 buffer_size=1e6,
                 prior_alpha=0.6,
                 prior_beta=0.4,
                 publish_param_interval=32,
                 max_step=1e5):
        self.env = gym.make(env_id)
        self.seed = seed
        self.lr = lr
        self.n_step = n_step
        self.gamma = gamma
        self.max_norm = max_norm
        self.target_update_interval = target_update_interval
        self.save_interval = save_interval
        self.publish_param_interval = publish_param_interval
        self.batch_size = batch_size
        self.prior_beta = prior_beta
        self.max_step = max_step

        self.buffer = CustomPrioritizedReplayBuffer(size=buffer_size,
                                                    alpha=prior_alpha)
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = DuelingDQN(self.env).to(self.device)
        self.tgt_model = DuelingDQN(self.env).to(self.device)
        self.tgt_model.load_state_dict(self.model.state_dict())
        self.optimizer = torch.optim.RMSprop(self.model.parameters(),
                                             self.lr,
                                             alpha=0.95,
                                             eps=1.5e-7,
                                             centered=True)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                         step_size=1000,
                                                         gamma=0.99)
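        # Anneal the prioritized-replay importance-sampling exponent beta from
        # prior_beta toward 1.0 over the first 1000 learner steps.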
        self.beta_by_frame = lambda frame_idx: min(
            1.0, self.prior_beta + frame_idx * (1.0 - self.prior_beta) / 1000)
        self.batch_recorder = BatchRecorder(env_id=env_id,
                                            env_seed=seed,
                                            n_workers=n_workers,
                                            buffer=self.buffer,
                                            n_steps=n_step,
                                            gamma=gamma,
                                            max_episode_length=50000)
        self.writer = SummaryWriter(
            comment="-{}-learner".format(self.env.unwrapped.spec.id))

    def train(self):
        utils.set_global_seeds(self.seed, use_torch=True)

        learn_idx = 0
        while True:
            beta = self.beta_by_frame(learn_idx)
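            # Draw a prioritized batch: `weights` are importance-sampling
            # corrections and `idxes` identify the sampled transitions so their
            # priorities can be updated once the new TD errors are known.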
            states, actions, rewards, next_states, dones, weights, idxes = self.buffer.sample(
                self.batch_size, beta)
            states = torch.FloatTensor(states).to(self.device)
            actions = torch.LongTensor(actions).to(self.device)
            rewards = torch.FloatTensor(rewards).to(self.device)
            next_states = torch.FloatTensor(next_states).to(self.device)
            dones = torch.FloatTensor(dones).to(self.device)
            weights = torch.FloatTensor(weights).to(self.device)
            batch = (states, actions, rewards, next_states, dones, weights)

            loss, prios = utils.compute_loss(self.model, self.tgt_model, batch,
                                             self.n_step, self.gamma)

            grad_norm = utils.update_parameters(loss, self.model,
                                                self.optimizer, self.max_norm)
            # Step the LR scheduler after the optimizer update (PyTorch >= 1.1
            # call order); StepLR decays the learning rate every 1000 steps.
            self.scheduler.step()

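            # Write the TD-error-derived priorities back into the buffer for
            # the transitions that were just sampled.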
            self.buffer.update_priorities(idxes, prios)

            batch, idxes, prios = None, None, None
            learn_idx += 1

            self.writer.add_scalar("learner/loss", loss, learn_idx)
            self.writer.add_scalar("learner/grad_norm", grad_norm, learn_idx)

            if learn_idx % self.target_update_interval == 0:
                print("Updating Target Network..")
                self.tgt_model.load_state_dict(self.model.state_dict())
            if learn_idx % self.save_interval == 0:
                print("Saving Model..")
                torch.save(self.model.state_dict(),
                           "model{}.pth".format(learn_idx))
            if learn_idx % self.publish_param_interval == 0:
                self.batch_recorder.set_worker_weights(
                    copy.deepcopy(self.model))
            if learn_idx >= self.max_step:
                torch.save(self.model.state_dict(),
                           "model{}.pth".format(learn_idx))
                self.batch_recorder.cleanup()
                break

    def load_model(self, idx):
        with open("model{}.pth".format(idx), "rb") as f:
            print("loading weights_{}".format(idx))
            self.model.load_state_dict(torch.load(f, map_location="cpu"))

    def sampling_data(self):
        self.batch_recorder.record_batch()
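
All five examples drive the buffer through the same three calls: construct it with a capacity and a priority exponent alpha, draw batches with sample(batch_size, beta), and push updated priorities back with update_priorities(idxes, prios). Those signatures mirror OpenAI baselines' PrioritizedReplayBuffer, so a minimal stand-alone sketch against that class (used here only as a stand-in for the custom buffer above) looks like this:

import numpy as np
from baselines.deepq.replay_buffer import PrioritizedReplayBuffer

# Stand-in for CustomPrioritizedReplayBuffer; sample() and update_priorities()
# are called exactly as in the examples above.
buffer = PrioritizedReplayBuffer(size=10000, alpha=0.6)

# Fill with dummy transitions (obs, action, reward, next_obs, done).
for _ in range(1000):
    buffer.add(np.zeros(4), 0, 0.0, np.zeros(4), False)

beta = 0.4  # importance-sampling exponent, annealed toward 1.0 during training
states, actions, rewards, next_states, dones, weights, idxes = buffer.sample(64, beta)

# New priorities would normally be |TD error| + eps from the learner.
new_prios = np.abs(np.random.randn(len(idxes))) + 1e-6
buffer.update_priorities(idxes, new_prios)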