예제 #1
0
def train_rainbow_dqn_conv(Q_main, Q_target, replay_buffer, image_buffer,
                           batch_size, gamma):
    """Run one training step of the twin-head conv Q-network on a PER batch.

    The regression target is

        y = r + gamma * (1 - t) * min(Q1_t(s', a1*), Q2_t(s', a2*))

    where ``a_i* = argmax_a Qi_main(s', a)``: the main network picks the next
    action, the target network evaluates it, and the smaller of the two
    heads is used.

    Args:
        Q_main: Network being trained. ``device``, ``input_channel_size``,
            ``height``, ``width``, ``action_dim``, ``optimizer`` and
            ``step_lr_scheduler`` are read from it; ``forward`` must return
            two ``(batch_size, action_dim)`` tensors.
        Q_target: Target network with the same ``forward`` contract.
        replay_buffer: Provides ``sample_batch(batch_size)`` returning
            ``(state_indices, actions, rewards, terminals, _, indices,
            weights)`` as well as ``update_priorities`` and ``update_beta``.
        image_buffer: Provides ``get_state_and_next(index)`` returning the
            frames of a state and of its successor.
        batch_size: Number of transitions to sample.
        gamma: Discount factor applied to the bootstrapped value.

    Returns:
        Tuple ``(mean(q1), min(q1), max(q1), mean(q2), mean(r_batch))`` of
        0-d tensors, intended for logging.
    """
    device = Q_main.device
    input_channel_size = Q_main.input_channel_size
    height = Q_main.height
    width = Q_main.width
    action_dim = Q_main.action_dim

    (s_idx_batch, a_batch, r_batch, t_batch, _, indices,
     weights) = replay_buffer.sample_batch(batch_size)

    assert s_idx_batch.shape == (batch_size, 1)

    s_batch = []
    s2_batch = []

    # Iterate over the flattened indices so int() receives a scalar; calling
    # int() on a 1-element array is deprecated since NumPy 1.25.
    for s_idx in s_idx_batch.reshape(-1):
        s_frame, s2_frame = image_buffer.get_state_and_next(int(s_idx))
        s_batch.append(s_frame)
        s2_batch.append(s2_frame)

    s_batch = torch.FloatTensor(np.array(s_batch)).to(device)
    s2_batch = torch.FloatTensor(np.array(s2_batch)).to(device)

    a_batch = torch.LongTensor(a_batch).to(device)  # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
    t_batch = torch.FloatTensor(t_batch).to(device)  # (batch_size, 1)

    # Importance-sampling weights as a column, (batch_size, 1).
    weights = torch.FloatTensor(weights).to(device).view(-1, 1)

    assert s_batch.shape == (batch_size, input_channel_size, height, width)
    assert s2_batch.shape == (batch_size, input_channel_size, height, width)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)
    assert t_batch.shape == (batch_size, 1)
    assert weights.shape == (batch_size, 1)

    q1, q2 = Q_main.forward(s_batch)

    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    q1_action = q1.gather(1, a_batch)  # Q[s,a]
    q2_action = q2.gather(1, a_batch)

    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    # Q[s,a] = r + gamma * targetmin(mainargmax(Q[s',a']))

    with torch.no_grad():
        # Action selection by the main network ...
        q1_s2_m, q2_s2_m = Q_main.forward(s2_batch)
        q1_a2_m = torch.argmax(q1_s2_m, dim=1, keepdim=True)
        q2_a2_m = torch.argmax(q2_s2_m, dim=1, keepdim=True)
        assert q1_a2_m.shape == (batch_size, 1)
        assert q2_a2_m.shape == (batch_size, 1)

        # ... action evaluation by the target network.
        q1_s2_t, q2_s2_t = Q_target.forward(s2_batch)
        assert q1_s2_t.shape == (batch_size, action_dim)
        assert q2_s2_t.shape == (batch_size, action_dim)

        q1_max = q1_s2_t.gather(1, q1_a2_m)
        q2_max = q2_s2_t.gather(1, q2_a2_m)
        assert q1_max.shape == (batch_size, 1)
        assert q2_max.shape == (batch_size, 1)

        min_Q_max = torch.min(q1_max, q2_max)
        assert min_Q_max.shape == (batch_size, 1)

        # (1 - t) removes the bootstrap term where t_batch is 1.
        y_q = r_batch + gamma * (1 - t_batch) * min_Q_max

    q1_l1 = torch.abs(q1_action - y_q)
    q2_l1 = torch.abs(q2_action - y_q)

    # PER Buffer Update: new priority is the mean absolute TD error of the
    # two heads.
    with torch.no_grad():
        priority = (q1_l1 + q2_l1) / 2.0
        priority = priority.cpu().numpy()
        replay_buffer.update_priorities(indices, priority)
        replay_buffer.update_beta()

    # Importance-weighted squared TD error.
    q1_loss = torch.mean((q1_l1**2) * weights)
    q2_loss = torch.mean((q2_l1**2) * weights)

    # The scalar losses are deliberately NOT clamped to [-1, 1]: clamp has a
    # zero derivative outside its range, so any batch with loss > 1 would
    # back-propagate all-zero gradients. Update size is bounded by the
    # gradient clipping below instead.
    Q_main.optimizer.zero_grad()
    # One backward pass over the summed loss produces the same accumulated
    # gradients as two separate passes, and also works if the two heads
    # share part of the graph.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()
    Q_main.step_lr_scheduler.step()

    # soft_target_update is defined elsewhere; presumably it blends Q_main's
    # weights into Q_target with rate 0.005 -- confirm against its definition.
    soft_target_update(Q_main, Q_target, 0.005)

    with torch.no_grad():
        return torch.mean(q1), torch.min(q1), torch.max(q1), torch.mean(
            q2), torch.mean(r_batch)
예제 #2
0
def train_discrete_SAC(Q_main, Q_target, replay_buffer, batch_size, gamma):
    """Run one discrete soft-Q training step on vector states.

    The regression target is

        y = r + gamma * (1 - t) * (H[pi_t(.|s')] + min_i E_{a'~pi_t}[Qi_t(s', a')])

    where ``pi_t`` is the distribution returned by
    ``Q_target.get_mean_distribution_from_Qs`` for the target Q-values.

    Args:
        Q_main: Network being trained. ``device``, ``state_dim``,
            ``action_dim`` and ``optimizer`` are read from it; ``forward``
            must return two ``(batch_size, action_dim)`` tensors.
        Q_target: Target network with the same ``forward`` contract and a
            ``get_mean_distribution_from_Qs(q1, q2)`` method returning
            ``(batch_size, action_dim)`` probabilities.
        replay_buffer: Provides ``sample_batch(batch_size)`` returning
            ``(states, actions, rewards, terminals, next_states)``.
            Terminals are assumed to be 1 where the episode ended, as in
            ``train_rainbow_dqn_conv`` -- confirm against the buffer.
        batch_size: Number of transitions to sample.
        gamma: Discount factor applied to the bootstrapped value.

    Returns:
        Tuple ``(max(q1), max(q2))`` of 0-d tensors, intended for logging.
    """
    device = Q_main.device
    state_dim = Q_main.state_dim
    action_dim = Q_main.action_dim

    s_batch, a_batch, r_batch, t_batch, s2_batch = replay_buffer.sample_batch(
        batch_size)
    s_batch = torch.FloatTensor(s_batch).to(device)  # (batch_size, state_dim)
    a_batch = torch.LongTensor(a_batch).to(device)  # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
    # Forced to a column so the mask below broadcasts element-wise.
    t_batch = torch.FloatTensor(t_batch).to(device).view(-1, 1)
    s2_batch = torch.FloatTensor(s2_batch).to(
        device)  # (batch_size, state_dim)

    assert s_batch.shape == (batch_size, state_dim)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)
    assert t_batch.shape == (batch_size, 1)
    assert s2_batch.shape == (batch_size, state_dim)

    q1, q2 = Q_main.forward(s_batch)

    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    # Q[s,a] = r + gamma * ( H[s'] + E_(a')[Q(s',a')] )

    q1_action = q1.gather(1, a_batch)  # Q[s,a]
    q2_action = q2.gather(1, a_batch)

    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    with torch.no_grad():
        q1_t, q2_t = Q_target.forward(s2_batch)
        assert q1_t.shape == (batch_size, action_dim)
        assert q2_t.shape == (batch_size, action_dim)

        target_probs = Q_target.get_mean_distribution_from_Qs(q1_t, q2_t)
        assert target_probs.shape == (batch_size, action_dim)

        target_policy_distribution_s2 = Categorical(target_probs)
        target_entropy_s2 = target_policy_distribution_s2.entropy().view(
            -1, 1)  # H[s']
        assert target_entropy_s2.shape == (batch_size, 1)

        # Expected target Q-value of each head under target_probs.
        E_q1_s2_t = torch.sum(target_probs * q1_t, dim=1, keepdim=True)
        E_q2_s2_t = torch.sum(target_probs * q2_t, dim=1, keepdim=True)
        assert E_q1_s2_t.shape == (batch_size, 1)
        assert E_q2_s2_t.shape == (batch_size, 1)

        q_target_min = torch.min(E_q1_s2_t, E_q2_s2_t)
        assert q_target_min.shape == (batch_size, 1)

        y_v = target_entropy_s2 + q_target_min
        # (1 - t) removes the bootstrap term for terminal transitions; the
        # sampled terminal flags were previously ignored.
        y_q = r_batch + gamma * (1 - t_batch) * y_v

    q1_loss = F.mse_loss(q1_action, y_q)
    q2_loss = F.mse_loss(q2_action, y_q)

    Q_main.optimizer.zero_grad()
    # One backward pass over the summed loss produces the same accumulated
    # gradients as two separate passes, and also works if the two heads
    # share part of the graph.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()

    # soft_target_update is defined elsewhere; presumably it blends Q_main's
    # weights into Q_target with rate 0.001 -- confirm against its definition.
    soft_target_update(Q_main, Q_target, 0.001)

    with torch.no_grad():
        return torch.max(q1), torch.max(q2)
def train_triple_dqn(Q_main, Q_target, replay_buffer, image_buffer, batch_size,
                     gamma):
    """Run one training step of the twin-head conv Q-network.

    The regression target is

        y = r + gamma * (1 - t) * min(max_a Q1_t(s', a), max_a Q2_t(s', a))

    with both maxima taken over the target network's outputs.

    Args:
        Q_main: Network being trained. ``device``, ``input_channel_size``,
            ``height``, ``width``, ``action_dim``, ``optimizer`` and
            ``step_lr_scheduler`` are read from it; ``forward`` must return
            two ``(batch_size, action_dim)`` tensors.
        Q_target: Target network with the same ``forward`` contract.
        replay_buffer: Provides ``sample_batch(batch_size)`` returning
            ``(state_indices, actions, rewards, terminals, _)``. Terminals
            are assumed to be 1 where the episode ended, as in
            ``train_rainbow_dqn_conv`` -- confirm against the buffer.
        image_buffer: Provides ``get_state_and_next(index)`` returning the
            frames of a state and of its successor.
        batch_size: Number of transitions to sample.
        gamma: Discount factor applied to the bootstrapped value.

    Returns:
        Tuple ``(mean(q1), min(q1), max(q1), mean(q2), mean(r_batch))`` of
        0-d tensors, intended for logging.
    """
    device = Q_main.device
    input_channel_size = Q_main.input_channel_size
    height = Q_main.height
    width = Q_main.width
    action_dim = Q_main.action_dim

    s_idx_batch, a_batch, r_batch, t_batch, _ = replay_buffer.sample_batch(
        batch_size)

    assert s_idx_batch.shape == (batch_size, 1)

    s_batch = []
    s2_batch = []

    # Iterate over the flattened indices so int() receives a scalar; calling
    # int() on a 1-element array is deprecated since NumPy 1.25.
    for s_idx in s_idx_batch.reshape(-1):
        s_frame, s2_frame = image_buffer.get_state_and_next(int(s_idx))
        s_batch.append(s_frame)
        s2_batch.append(s2_frame)

    s_batch = torch.FloatTensor(np.array(s_batch)).to(device)
    s2_batch = torch.FloatTensor(np.array(s2_batch)).to(device)

    a_batch = torch.LongTensor(a_batch).to(device)  # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
    # Forced to a column so the mask below broadcasts element-wise.
    t_batch = torch.FloatTensor(t_batch).to(device).view(-1, 1)

    assert s_batch.shape == (batch_size, input_channel_size, height, width)
    assert s2_batch.shape == (batch_size, input_channel_size, height, width)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)
    assert t_batch.shape == (batch_size, 1)

    q1, q2 = Q_main.forward(s_batch)

    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    q1_action = q1.gather(1, a_batch)  # Q[s,a]
    q2_action = q2.gather(1, a_batch)

    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    # Q[s,a] = r + gamma * targetmin(max(Q[s',a']))

    with torch.no_grad():
        q1_s2_t, q2_s2_t = Q_target.forward(s2_batch)
        assert q1_s2_t.shape == (batch_size, action_dim)
        assert q2_s2_t.shape == (batch_size, action_dim)

        q1_max = torch.max(q1_s2_t, dim=1, keepdim=True)[0]
        q2_max = torch.max(q2_s2_t, dim=1, keepdim=True)[0]
        assert q1_max.shape == (batch_size, 1)
        assert q2_max.shape == (batch_size, 1)

        min_Q_max = torch.min(q1_max, q2_max)
        assert min_Q_max.shape == (batch_size, 1)

        # (1 - t) removes the bootstrap term for terminal transitions; the
        # sampled terminal flags were previously ignored.
        y_q = r_batch + gamma * (1 - t_batch) * min_Q_max

    q1_loss = F.mse_loss(q1_action, y_q)
    q2_loss = F.mse_loss(q2_action, y_q)

    # The scalar losses are deliberately NOT clamped to [-1, 1]: clamp has a
    # zero derivative outside its range, so any batch with loss > 1 would
    # back-propagate all-zero gradients. Update size is bounded by the
    # gradient clipping below instead.
    Q_main.optimizer.zero_grad()
    # One backward pass over the summed loss produces the same accumulated
    # gradients as two separate passes, and also works if the two heads
    # share part of the graph.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()
    Q_main.step_lr_scheduler.step()

    # soft_target_update is defined elsewhere; presumably it blends Q_main's
    # weights into Q_target with rate 0.005 -- confirm against its definition.
    soft_target_update(Q_main, Q_target, 0.005)

    with torch.no_grad():
        return torch.mean(q1), torch.min(q1), torch.max(q1), torch.mean(
            q2), torch.mean(r_batch)


# #Frame buffer
# def train_triple_dqn(Q_main, Q_target, replay_buffer, batch_size, gamma):
#     device = Q_main.device
#     input_channel_size = Q_main.input_channel_size
#     height = Q_main.height
#     width = Q_main.width
#     action_dim = Q_main.action_dim
#
#     s_batch, a_batch, r_batch, t_batch, s2_batch = replay_buffer.sample_batch(batch_size)
#     s_batch = torch.FloatTensor(s_batch).to(device)  # (batch_size, input_channel_size, height, width)
#     a_batch = torch.LongTensor(a_batch).to(device)  # (batch_size, action_dim)
#     r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
#     s2_batch = torch.FloatTensor(s2_batch).to(device)  # (batch_size, input_channel_size, height, width)
#
#     assert s_batch.shape == (batch_size, input_channel_size, height, width)
#     assert a_batch.shape == (batch_size, 1)
#     assert r_batch.shape == (batch_size, 1)
#     assert s2_batch.shape == (batch_size, input_channel_size, height, width)
#
#     q1, q2, q3 = Q_main.forward(s_batch)
#
#     assert q1.shape == (batch_size, action_dim)
#     assert q2.shape == (batch_size, action_dim)
#     assert q3.shape == (batch_size, action_dim)
#
#     q1_action = q1.gather(1, a_batch)  # Q[s,a]
#     q2_action = q2.gather(1, a_batch)
#     q3_action = q3.gather(1, a_batch)
#
#     assert q1_action.shape == (batch_size, 1)
#     assert q2_action.shape == (batch_size, 1)
#     assert q3_action.shape == (batch_size, 1)
#
#     # Q[s,a] = r + gamma * targetmin(max(Q[s',a']))
#
#     with torch.no_grad():
#         q1_s2_t, q2_s2_t, q3_s2_t = Q_target.forward(s2_batch)
#         assert q1_s2_t.shape == (batch_size, action_dim)
#         assert q2_s2_t.shape == (batch_size, action_dim)
#         assert q3_s2_t.shape == (batch_size, action_dim)
#
#         q1_max = torch.max(q1_s2_t, dim=1, keepdim=True)[0]
#         q2_max = torch.max(q2_s2_t, dim=1, keepdim=True)[0]
#         q3_max = torch.max(q3_s2_t, dim=1, keepdim=True)[0]
#         assert q1_max.shape == (batch_size, 1)
#         assert q2_max.shape == (batch_size, 1)
#         assert q3_max.shape == (batch_size, 1)
#
#         min_Q_max = torch.min(torch.min(q1_max, q2_max),q3_max)
#         assert min_Q_max.shape == (batch_size, 1)
#
#         y_q = r_batch + gamma * min_Q_max
#
#     q1_loss = F.mse_loss(q1_action, y_q)
#     q2_loss = F.mse_loss(q2_action, y_q)
#     q3_loss = F.mse_loss(q3_action, y_q)
#
#     Q_main.optimizer.zero_grad()
#     q1_loss.backward()
#     q2_loss.backward()
#     q3_loss.backward()
#     torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
#     Q_main.optimizer.step()
#
#     soft_target_update(Q_main, Q_target, 0.005)
#
#     with torch.no_grad():
#         return torch.max(q1), torch.max(q2), torch.max(q3), np.mean(entropy(q1.detach().cpu().numpy().T)),torch.mean(r_batch)
예제 #4
0
def train_discrete_Conv_SAC_max(Q_main, Q_target, replay_buffer, batch_size,
                                gamma, alpha):
    """Run one soft-Q training step on image states using a max backup.

    The regression target is

        y = r + gamma * (1 - t) * (alpha * H[pi_m(.|s')]
                                   + min_i max_a Qi_t(s', a))

    where ``pi_m`` is the distribution returned by
    ``Q_main.get_mean_distribution_from_Qs`` for the main network's
    Q-values at ``s'`` and ``Qi_t`` are the target network's heads.

    Args:
        Q_main: Network being trained. ``device``, ``input_channel_size``,
            ``height``, ``width``, ``action_dim`` and ``optimizer`` are read
            from it; ``forward`` must return two
            ``(batch_size, action_dim)`` tensors and
            ``get_mean_distribution_from_Qs(q1, q2)`` must return
            ``(batch_size, action_dim)`` probabilities.
        Q_target: Target network with the same ``forward`` contract.
        replay_buffer: Provides ``sample_batch(batch_size)`` returning
            ``(states, actions, rewards, terminals, next_states)``.
            Terminals are assumed to be 1 where the episode ended, as in
            ``train_rainbow_dqn_conv`` -- confirm against the buffer.
        batch_size: Number of transitions to sample.
        gamma: Discount factor applied to the bootstrapped value.
        alpha: Weight of the entropy term in the target.

    Returns:
        Tuple ``(max(q1), max(q2), mean(entropy at s'), mean(r_batch))`` of
        0-d tensors, intended for logging.
    """
    device = Q_main.device
    input_channel_size = Q_main.input_channel_size
    height = Q_main.height
    width = Q_main.width
    action_dim = Q_main.action_dim

    s_batch, a_batch, r_batch, t_batch, s2_batch = replay_buffer.sample_batch(
        batch_size)
    s_batch = torch.FloatTensor(s_batch).to(
        device)  # (batch_size, input_channel_size, height, width)
    a_batch = torch.LongTensor(a_batch).to(device)  # (batch_size, 1)
    r_batch = torch.FloatTensor(r_batch).to(device)  # (batch_size, 1)
    # Forced to a column so the mask below broadcasts element-wise.
    t_batch = torch.FloatTensor(t_batch).to(device).view(-1, 1)
    s2_batch = torch.FloatTensor(s2_batch).to(
        device)  # (batch_size, input_channel_size, height, width)

    assert s_batch.shape == (batch_size, input_channel_size, height, width)
    assert a_batch.shape == (batch_size, 1)
    assert r_batch.shape == (batch_size, 1)
    assert t_batch.shape == (batch_size, 1)
    assert s2_batch.shape == (batch_size, input_channel_size, height, width)

    q1, q2 = Q_main.forward(s_batch)

    assert q1.shape == (batch_size, action_dim)
    assert q2.shape == (batch_size, action_dim)

    # Q[s,a] = r + gamma * ( H[s'] + max[Q(s',a')] )

    q1_action = q1.gather(1, a_batch)  # Q[s,a]
    q2_action = q2.gather(1, a_batch)

    assert q1_action.shape == (batch_size, 1)
    assert q2_action.shape == (batch_size, 1)

    with torch.no_grad():
        q1_t, q2_t = Q_target.forward(s2_batch)
        assert q1_t.shape == (batch_size, action_dim)
        assert q2_t.shape == (batch_size, action_dim)

        # The entropy bonus comes from the main network's distribution at
        # s', not the target network's.
        q1_m, q2_m = Q_main.forward(s2_batch)
        assert q1_m.shape == (batch_size, action_dim)
        assert q2_m.shape == (batch_size, action_dim)

        main_probs = Q_main.get_mean_distribution_from_Qs(q1_m, q2_m)
        assert main_probs.shape == (batch_size, action_dim)

        main_policy_distribution_s2 = Categorical(main_probs)
        main_entropy_s2 = main_policy_distribution_s2.entropy().view(
            -1, 1)  # H[s']
        assert main_entropy_s2.shape == (batch_size, 1)

        # Greedy (max over actions) value of each target head.
        q1_s2_max = torch.max(q1_t, dim=1, keepdim=True)[0]
        q2_s2_max = torch.max(q2_t, dim=1, keepdim=True)[0]
        assert q1_s2_max.shape == (batch_size, 1)
        assert q2_s2_max.shape == (batch_size, 1)

        q_target_min = torch.min(q1_s2_max, q2_s2_max)
        assert q_target_min.shape == (batch_size, 1)

        y_v = alpha * main_entropy_s2 + q_target_min
        # (1 - t) removes the bootstrap term for terminal transitions; the
        # sampled terminal flags were previously ignored.
        y_q = r_batch + gamma * (1 - t_batch) * y_v

    q1_loss = F.mse_loss(q1_action, y_q)
    q2_loss = F.mse_loss(q2_action, y_q)

    Q_main.optimizer.zero_grad()
    # One backward pass over the summed loss produces the same accumulated
    # gradients as two separate passes, and also works if the two heads
    # share part of the graph.
    (q1_loss + q2_loss).backward()
    torch.nn.utils.clip_grad_value_(Q_main.parameters(), 1.0)
    Q_main.optimizer.step()

    # soft_target_update is defined elsewhere; presumably it blends Q_main's
    # weights into Q_target with rate 0.005 -- confirm against its definition.
    soft_target_update(Q_main, Q_target, 0.005)

    with torch.no_grad():
        return torch.max(q1), torch.max(q2), torch.mean(
            main_entropy_s2), torch.mean(r_batch)