action = policy.act(state, total_game_step, isTrain=True).to(DEVICE)  # sample an action from the policy
next_state, reward, done, _ = env.step(action.item())  # take the action in the environment
total_reward += reward
reward = torch.tensor([reward]).float().to(DEVICE)

if done:  # the episode has terminated (game over)
    next_state = None
else:
    next_state = torch.tensor([next_state]).float().to(DEVICE)

replay_buffer.store(state, action, reward, next_state)
state = next_state

# optimize the model with a BATCH_SIZE sample from the buffer
if replay_buffer.length() > BATCH_SIZE:  # only optimize once the replay buffer holds enough data
    samples = replay_buffer.sample(BATCH_SIZE)
    samples = experience_sample(*zip(*samples))
    state_batch = torch.cat(samples.state)
    action_batch = torch.cat(samples.action)
    reward_batch = torch.cat(samples.reward)

    # get the Q-value Q(s(j), a(j))
    q_value_array = policy(state_batch)  # Q-values of all 4 actions: [Q(a0), Q(a1), Q(a2), Q(a3)]
    q_value = q_value_array.gather(1, action_batch)

    # set y(j) = r(j)                                if next_state(j+1) is terminal
    #     y(j) = r(j) + GAMMA * max_a Q(s(j+1), a)   otherwise
    # Note: the Q-function of the target network is used for the max term
    terminal_mask = torch.tensor(tuple(map(lambda s: s is not None, samples.next_state)), device=DEVICE)  # True where next_state is non-terminal
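The loop above leans on a few helpers that are defined outside this excerpt, notably the `experience_sample` tuple and the `replay_buffer` object with its `store`, `length`, and `sample` methods. For reference, here is a minimal sketch of what they could look like; the FIFO-deque implementation, the `capacity` parameter, and the example capacity of 10000 are assumptions on my part, not the post's actual definitions:

import random
from collections import deque, namedtuple

# Hypothetical definitions -- the real ones live elsewhere in the full source.
experience_sample = namedtuple(
    "experience_sample", ("state", "action", "reward", "next_state")
)

class ReplayBuffer:
    """Fixed-size FIFO buffer of transitions, a common DQN pattern."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)  # oldest transitions fall out automatically

    def store(self, state, action, reward, next_state):
        # one entry per environment step
        self.memory.append(experience_sample(state, action, reward, next_state))

    def length(self):
        return len(self.memory)

    def sample(self, batch_size):
        # uniform random sampling decorrelates consecutive transitions
        return random.sample(self.memory, batch_size)

replay_buffer = ReplayBuffer(capacity=10000)

With this layout, `experience_sample(*zip(*samples))` transposes a list of per-step transitions into a single tuple of per-field batches, which is exactly what the `torch.cat` calls above expect.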