Example #1
        action = policy.act(state, total_game_step, isTrain=True).to(DEVICE) # sample an action from the policy
        next_state, reward, done, _ = env.step(action.item()) # take the action in the environment
        total_reward += reward
        reward = torch.tensor([reward]).float().to(DEVICE)
        
        if done: # whether this episode has terminated (game over)
            next_state = None
        else:
            next_state = torch.tensor([next_state]).float().to(DEVICE)
        
        replay_buffer.store(state, action, reward, next_state)
        state = next_state

        # optimize the model with a batch of BATCH_SIZE samples from the replay buffer

        if replay_buffer.lenth() > BATCH_SIZE: # only optimize once the replay buffer holds enough transitions
            samples = replay_buffer.sample(BATCH_SIZE)
            samples = experience_sample(*zip(*samples))
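            # experience_sample(*zip(*samples)) transposes the list of per-step transitions into a single
            # batch: each field (state, action, reward, next_state) becomes a tuple of tensors
            # (this assumes experience_sample is a namedtuple with exactly those four fields)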
            state_batch = torch.cat(samples.state)
            action_batch = torch.cat(samples.action)
            reward_batch = torch.cat(samples.reward)

            # get the Q-value Q(s(j), a(j))
            q_value_array = policy(state_batch) # Q-values of all 4 actions: [Q(a0), Q(a1), Q(a2), Q(a3)]
            q_value = q_value_array.gather(1, action_batch)
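            # gather(1, action_batch) selects, for each row j, the Q-value of the action actually taken,
            # so action_batch is expected to have shape [BATCH_SIZE, 1]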

            # set y(j) = r(j)                                   --- if the next state s(j+1) is terminal
            #     y(j) = r(j) + GAMMA * max_a Q(s(j+1), a)      --- if s(j+1) is non-terminal
            # Note : the Q-values of s(j+1) come from the target_network

            # mask is True for transitions whose next state is non-terminal (next_state is not None)
            terminal_mask = torch.tensor(tuple(map(lambda a: a is not None, samples.next_state)), device=DEVICE)
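            # --- Minimal sketch (not part of the original listing) of how the target y(j) described
            # --- above could be computed, assuming a target_network with the same architecture as
            # --- `policy` and a discount factor GAMMA are defined elsewhere.
            non_terminal_next_states = torch.cat([s for s in samples.next_state if s is not None])
            next_q_value = torch.zeros(BATCH_SIZE, device=DEVICE)
            with torch.no_grad(): # target values must not propagate gradients into the networks
                next_q_value[terminal_mask] = target_network(non_terminal_next_states).max(1)[0]
            # y(j) = r(j) for terminal states, r(j) + GAMMA * max_a Q_target(s(j+1), a) otherwise
            expected_q_value = reward_batch + GAMMA * next_q_value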