import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def train_rainbow_dqn_conv(Q_main, Q_target, replay_buffer, image_buffer,
                           batch_size, gamma):
    device = Q_main.device
    input_channel_size = Q_main.input_channel_size
    height = Q_main.height
    width = Q_main.width
    action_dim = Q_main.action_dim

    # Sample PER indices, then rebuild the frame stacks from the image buffer.
    s_idx_batch, a_batch, r_batch, t_batch, _, indices, weights = replay_buffer.sample_batch(
        batch_size)
    assert s_idx_batch.shape == (batch_size, 1)

    s_batch = []
    s2_batch = []
    for s_idx in s_idx_batch:
        s_frame, s2_frame = image_buffer.get_state_and_next(int(s_idx))
        s_batch.append(s_frame)
        s2_batch.append(s2_frame)

    s_batch = torch.FloatTensor(np.array(s_batch)).to(device)
    s2_batch = torch.FloatTensor(np.array(s2_batch)).to(device)
    a_batch = torch.LongTensor(a_batch).to(device)   # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
    t_batch = torch.FloatTensor(t_batch).to(device)  # (batch_size, 1)
    weights = torch.FloatTensor(weights).to(device).view(-1, 1)  # (batch_size, 1)

    assert s_batch.shape == (batch_size, input_channel_size, height, width)
    assert s2_batch.shape == (batch_size, input_channel_size, height, width)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)
    assert t_batch.shape == (batch_size, 1)
    assert weights.shape == (batch_size, 1)

    q1, q2 = Q_main.forward(s_batch)
    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    q1_action = q1.gather(1, a_batch)  # Q[s, a]
    q2_action = q2.gather(1, a_batch)
    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    # Clipped double-Q target:
    # Q[s,a] = r + gamma * min_i( Q_target_i[s', argmax_a' Q_main_i[s', a']] )
    with torch.no_grad():
        q1_s2_m, q2_s2_m = Q_main.forward(s2_batch)
        q1_a2_m = torch.argmax(q1_s2_m, dim=1, keepdim=True)
        q2_a2_m = torch.argmax(q2_s2_m, dim=1, keepdim=True)
        assert q1_a2_m.shape == (batch_size, 1)
        assert q2_a2_m.shape == (batch_size, 1)

        q1_s2_t, q2_s2_t = Q_target.forward(s2_batch)
        assert q1_s2_t.shape == (batch_size, action_dim)
        assert q2_s2_t.shape == (batch_size, action_dim)

        q1_max = q1_s2_t.gather(1, q1_a2_m)
        q2_max = q2_s2_t.gather(1, q2_a2_m)
        assert q1_max.shape == (batch_size, 1)
        assert q2_max.shape == (batch_size, 1)

        min_Q_max = torch.min(q1_max, q2_max)
        assert min_Q_max.shape == (batch_size, 1)

        # Bootstrap is masked at terminal transitions via (1 - t_batch).
        y_q = r_batch + gamma * (1 - t_batch) * min_Q_max

    q1_l1 = torch.abs(q1_action - y_q)
    q2_l1 = torch.abs(q2_action - y_q)

    # PER update: new priorities are the mean absolute TD error per sample.
    with torch.no_grad():
        priority = (q1_l1 + q2_l1) / 2.0
        priority = priority.cpu().numpy()
        replay_buffer.update_priorities(indices, priority)
        replay_buffer.update_beta()

    # Importance-sampling weighted squared TD error.
    q1_loss = torch.mean((q1_l1 ** 2) * weights)
    q2_loss = torch.mean((q2_l1 ** 2) * weights)
    # Note: clamping the scalar loss zeroes its gradient once it saturates at 1.0.
    q1_loss = torch.clamp(q1_loss, -1.0, 1.0)
    q2_loss = torch.clamp(q2_loss, -1.0, 1.0)

    Q_main.optimizer.zero_grad()
    # Both heads come from one forward pass, so backpropagate a single
    # combined loss; two separate backward() calls would fail once the
    # shared graph is freed.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()
    Q_main.step_lr_scheduler.step()

    soft_target_update(Q_main, Q_target, 0.005)

    with torch.no_grad():
        return torch.mean(q1), torch.min(q1), torch.max(q1), torch.mean(
            q2), torch.mean(r_batch)
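# soft_target_update is called throughout this file but defined elsewhere.
# The sketch below is a plausible Polyak-averaging implementation consistent
# with the call sites soft_target_update(Q_main, Q_target, tau); the repo's
# actual helper may differ.
def _soft_target_update_sketch(main_net, target_net, tau):
    """target <- (1 - tau) * target + tau * main, parameter-wise."""
    with torch.no_grad():
        for p_main, p_target in zip(main_net.parameters(),
                                    target_net.parameters()):
            p_target.mul_(1.0 - tau).add_(p_main, alpha=tau)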
def train_discrete_SAC(Q_main, Q_target, replay_buffer, batch_size, gamma):
    device = Q_main.device
    state_dim = Q_main.state_dim
    action_dim = Q_main.action_dim

    s_batch, a_batch, r_batch, t_batch, s2_batch = replay_buffer.sample_batch(
        batch_size)
    s_batch = torch.FloatTensor(s_batch).to(device)    # (batch_size, state_dim)
    a_batch = torch.LongTensor(a_batch).to(device)     # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)    # (batch_size, 1)
    t_batch = torch.FloatTensor(t_batch).to(device)    # (batch_size, 1)
    s2_batch = torch.FloatTensor(s2_batch).to(device)  # (batch_size, state_dim)

    assert s_batch.shape == (batch_size, state_dim)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)
    assert s2_batch.shape == (batch_size, state_dim)

    q1, q2 = Q_main.forward(s_batch)
    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    # Soft Q target: Q[s,a] = r + gamma * ( H[s'] + E_{a'}[Q(s',a')] )
    q1_action = q1.gather(1, a_batch)  # Q[s, a]
    q2_action = q2.gather(1, a_batch)
    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    with torch.no_grad():
        q1_t, q2_t = Q_target.forward(s2_batch)
        assert q1_t.shape == (batch_size, action_dim)
        assert q2_t.shape == (batch_size, action_dim)

        target_probs = Q_target.get_mean_distribution_from_Qs(q1_t, q2_t)
        assert target_probs.shape == (batch_size, action_dim)

        target_policy_distribution_s2 = Categorical(target_probs)
        target_entropy_s2 = target_policy_distribution_s2.entropy().view(
            -1, 1)  # H[s']
        assert target_entropy_s2.shape == (batch_size, 1)

        E_q1_s2_t = torch.sum(target_probs * q1_t, dim=1, keepdim=True)
        E_q2_s2_t = torch.sum(target_probs * q2_t, dim=1, keepdim=True)
        assert E_q1_s2_t.shape == (batch_size, 1)
        assert E_q2_s2_t.shape == (batch_size, 1)

        q_target_min = torch.min(E_q1_s2_t, E_q2_s2_t)
        assert q_target_min.shape == (batch_size, 1)

        y_v = target_entropy_s2 + q_target_min
        # Bootstrap is masked at terminal transitions via (1 - t_batch).
        y_q = r_batch + gamma * (1 - t_batch) * y_v

    q1_loss = F.mse_loss(q1_action, y_q)
    q2_loss = F.mse_loss(q2_action, y_q)

    Q_main.optimizer.zero_grad()
    # Single backward pass over the graph shared by both heads.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()

    soft_target_update(Q_main, Q_target, 0.001)

    with torch.no_grad():
        return torch.max(q1), torch.max(q2)
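# get_mean_distribution_from_Qs is a method on the Q network and is not shown
# in this section. Judging from how its output is fed to Categorical as a
# probability table over actions, one plausible implementation is a softmax
# over the average of the two heads; the actual method may differ.
def _mean_distribution_from_qs_sketch(q1, q2):
    """Softmax policy over the mean of two Q heads: (batch, action_dim) probs."""
    return F.softmax((q1 + q2) / 2.0, dim=1)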
def train_triple_dqn(Q_main, Q_target, replay_buffer, image_buffer,
                     batch_size, gamma):
    device = Q_main.device
    input_channel_size = Q_main.input_channel_size
    height = Q_main.height
    width = Q_main.width
    action_dim = Q_main.action_dim

    s_idx_batch, a_batch, r_batch, t_batch, _ = replay_buffer.sample_batch(
        batch_size)
    assert s_idx_batch.shape == (batch_size, 1)

    s_batch = []
    s2_batch = []
    for s_idx in s_idx_batch:
        s_frame, s2_frame = image_buffer.get_state_and_next(int(s_idx))
        s_batch.append(s_frame)
        s2_batch.append(s2_frame)

    s_batch = torch.FloatTensor(np.array(s_batch)).to(device)
    s2_batch = torch.FloatTensor(np.array(s2_batch)).to(device)
    a_batch = torch.LongTensor(a_batch).to(device)   # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
    t_batch = torch.FloatTensor(t_batch).to(device)  # (batch_size, 1)

    assert s_batch.shape == (batch_size, input_channel_size, height, width)
    assert s2_batch.shape == (batch_size, input_channel_size, height, width)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)

    q1, q2 = Q_main.forward(s_batch)
    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    q1_action = q1.gather(1, a_batch)  # Q[s, a]
    q2_action = q2.gather(1, a_batch)
    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    # Clipped double-Q target: Q[s,a] = r + gamma * min_i( max_a' Q_target_i[s', a'] )
    with torch.no_grad():
        q1_s2_t, q2_s2_t = Q_target.forward(s2_batch)
        assert q1_s2_t.shape == (batch_size, action_dim)
        assert q2_s2_t.shape == (batch_size, action_dim)

        q1_max = torch.max(q1_s2_t, dim=1, keepdim=True)[0]
        q2_max = torch.max(q2_s2_t, dim=1, keepdim=True)[0]
        assert q1_max.shape == (batch_size, 1)
        assert q2_max.shape == (batch_size, 1)

        min_Q_max = torch.min(q1_max, q2_max)
        assert min_Q_max.shape == (batch_size, 1)

        # Bootstrap is masked at terminal transitions via (1 - t_batch).
        y_q = r_batch + gamma * (1 - t_batch) * min_Q_max

    q1_loss = F.mse_loss(q1_action, y_q)
    q2_loss = F.mse_loss(q2_action, y_q)
    # Note: clamping the scalar loss zeroes its gradient once it saturates at 1.0.
    q1_loss = torch.clamp(q1_loss, -1.0, 1.0)
    q2_loss = torch.clamp(q2_loss, -1.0, 1.0)

    Q_main.optimizer.zero_grad()
    # Single backward pass over the graph shared by both heads.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()
    Q_main.step_lr_scheduler.step()

    soft_target_update(Q_main, Q_target, 0.005)

    with torch.no_grad():
        return torch.mean(q1), torch.min(q1), torch.max(q1), torch.mean(
            q2), torch.mean(r_batch)


# Frame-buffer variant of train_triple_dqn with three Q heads (kept for reference):
#
# def train_triple_dqn(Q_main, Q_target, replay_buffer, batch_size, gamma):
#     device = Q_main.device
#     input_channel_size = Q_main.input_channel_size
#     height = Q_main.height
#     width = Q_main.width
#     action_dim = Q_main.action_dim
#
#     s_batch, a_batch, r_batch, t_batch, s2_batch = replay_buffer.sample_batch(batch_size)
#     s_batch = torch.FloatTensor(s_batch).to(device)    # (batch_size, input_channel_size, height, width)
#     a_batch = torch.LongTensor(a_batch).to(device)     # (batch_size, 1)
#     r_batch = torch.FloatTensor(r_batch).to(device)    # (batch_size, 1)
#     s2_batch = torch.FloatTensor(s2_batch).to(device)  # (batch_size, input_channel_size, height, width)
#
#     assert s_batch.shape == (batch_size, input_channel_size, height, width)
#     assert a_batch.shape == (batch_size, 1)
#     assert r_batch.shape == (batch_size, 1)
#     assert s2_batch.shape == (batch_size, input_channel_size, height, width)
#
#     q1, q2, q3 = Q_main.forward(s_batch)
#     assert q1.shape == (batch_size, action_dim)
#     assert q2.shape == (batch_size, action_dim)
#     assert q3.shape == (batch_size, action_dim)
#
#     q1_action = q1.gather(1, a_batch)  # Q[s, a]
#     q2_action = q2.gather(1, a_batch)
#     q3_action = q3.gather(1, a_batch)
#     assert q1_action.shape == (batch_size, 1)
#     assert q2_action.shape == (batch_size, 1)
#     assert q3_action.shape == (batch_size, 1)
#
#     # Q[s,a] = r + gamma * min_i( max_a' Q_target_i[s', a'] )
#     with torch.no_grad():
#         q1_s2_t, q2_s2_t, q3_s2_t = Q_target.forward(s2_batch)
#         assert q1_s2_t.shape == (batch_size, action_dim)
#         assert q2_s2_t.shape == (batch_size, action_dim)
#         assert q3_s2_t.shape == (batch_size, action_dim)
#
#         q1_max = torch.max(q1_s2_t, dim=1, keepdim=True)[0]
#         q2_max = torch.max(q2_s2_t, dim=1, keepdim=True)[0]
#         q3_max = torch.max(q3_s2_t, dim=1, keepdim=True)[0]
#         assert q1_max.shape == (batch_size, 1)
#         assert q2_max.shape == (batch_size, 1)
#         assert q3_max.shape == (batch_size, 1)
#
#         min_Q_max = torch.min(torch.min(q1_max, q2_max), q3_max)
#         assert min_Q_max.shape == (batch_size, 1)
#
#         y_q = r_batch + gamma * min_Q_max
#
#     q1_loss = F.mse_loss(q1_action, y_q)
#     q2_loss = F.mse_loss(q2_action, y_q)
#     q3_loss = F.mse_loss(q3_action, y_q)
#
#     Q_main.optimizer.zero_grad()
#     # Single backward pass over the graph shared by the three heads.
#     (q1_loss + q2_loss + q3_loss).backward()
#     torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
#     Q_main.optimizer.step()
#
#     soft_target_update(Q_main, Q_target, 0.005)
#
#     with torch.no_grad():
#         # entropy: presumably scipy.stats.entropy over each action distribution.
#         return torch.max(q1), torch.max(q2), torch.max(q3), np.mean(
#             entropy(q1.detach().cpu().numpy().T)), torch.mean(r_batch)
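# The image_buffer used by the conv trainers above stores raw frames once and
# rebuilds stacked states on demand via get_state_and_next(idx). The class
# below is a hypothetical minimal sketch of that interface, not the repo's
# actual buffer (which likely also handles episode boundaries and wrap-around).
class _FrameIndexBufferSketch:
    def __init__(self, frames, stack=4):
        self.frames = frames  # (N, height, width) array of consecutive frames
        self.stack = stack    # frames per state == input_channel_size

    def get_state_and_next(self, idx):
        s = self.frames[idx:idx + self.stack]           # stacked state s
        s2 = self.frames[idx + 1:idx + 1 + self.stack]  # successor state s'
        return s, s2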
def train_discrete_Conv_SAC_max(Q_main, Q_target, replay_buffer, batch_size,
                                gamma, alpha):
    device = Q_main.device
    input_channel_size = Q_main.input_channel_size
    height = Q_main.height
    width = Q_main.width
    action_dim = Q_main.action_dim

    s_batch, a_batch, r_batch, t_batch, s2_batch = replay_buffer.sample_batch(
        batch_size)
    s_batch = torch.FloatTensor(s_batch).to(
        device)  # (batch_size, input_channel_size, height, width)
    a_batch = torch.LongTensor(a_batch).to(device)   # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
    t_batch = torch.FloatTensor(t_batch).to(device)  # (batch_size, 1)
    s2_batch = torch.FloatTensor(s2_batch).to(
        device)  # (batch_size, input_channel_size, height, width)

    assert s_batch.shape == (batch_size, input_channel_size, height, width)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)
    assert s2_batch.shape == (batch_size, input_channel_size, height, width)

    q1, q2 = Q_main.forward(s_batch)
    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    # Max-entropy target with a hard max over next-state actions:
    # Q[s,a] = r + gamma * ( alpha * H[s'] + max_a' Q(s',a') )
    q1_action = q1.gather(1, a_batch)  # Q[s, a]
    q2_action = q2.gather(1, a_batch)
    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    with torch.no_grad():
        q1_t, q2_t = Q_target.forward(s2_batch)
        assert q1_t.shape == (batch_size, action_dim)
        assert q2_t.shape == (batch_size, action_dim)

        q1_m, q2_m = Q_main.forward(s2_batch)
        assert q1_m.shape == (batch_size, action_dim)
        assert q2_m.shape == (batch_size, action_dim)

        main_probs = Q_main.get_mean_distribution_from_Qs(q1_m, q2_m)
        assert main_probs.shape == (batch_size, action_dim)

        main_policy_distribution_s2 = Categorical(main_probs)
        main_entropy_s2 = main_policy_distribution_s2.entropy().view(
            -1, 1)  # H[s']
        assert main_entropy_s2.shape == (batch_size, 1)

        max_q1_s2_t = torch.max(q1_t, dim=1, keepdim=True)[0]
        max_q2_s2_t = torch.max(q2_t, dim=1, keepdim=True)[0]
        assert max_q1_s2_t.shape == (batch_size, 1)
        assert max_q2_s2_t.shape == (batch_size, 1)

        q_target_min = torch.min(max_q1_s2_t, max_q2_s2_t)
        assert q_target_min.shape == (batch_size, 1)

        y_v = alpha * main_entropy_s2 + q_target_min
        # Bootstrap is masked at terminal transitions via (1 - t_batch).
        y_q = r_batch + gamma * (1 - t_batch) * y_v

    q1_loss = F.mse_loss(q1_action, y_q)
    q2_loss = F.mse_loss(q2_action, y_q)

    Q_main.optimizer.zero_grad()
    # Single backward pass over the graph shared by both heads.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()

    soft_target_update(Q_main, Q_target, 0.005)

    with torch.no_grad():
        return torch.max(q1), torch.max(q2), torch.mean(
            main_entropy_s2), torch.mean(r_batch)
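# Hypothetical usage sketch: how one of the trainers above might be driven
# from a training loop. The buffer population, environment interaction, and
# hyperparameters here are illustrative assumptions, not repo code.
#
#     for step in range(total_steps):
#         ...collect one transition into replay_buffer...
#         if replay_buffer.size >= batch_size:
#             q1_max, q2_max, ent, r_mean = train_discrete_Conv_SAC_max(
#                 Q_main, Q_target, replay_buffer,
#                 batch_size=64, gamma=0.99, alpha=0.2)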