def compute_rainbow_q_values(policy, model, obs, explore, is_training=False):
    """Supports both normal DQN and distributional DQN.

    Reference:
    https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn_tf_policy.py
    """
    config = policy.config

    # NOTE: Toggle train/eval so noisy layers resample (train) or freeze (eval)
    # their noise according to the training flag.
    if is_training:
        model.train()
    else:
        model.eval()

    model_out, state = model.get_embeddings({
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }, [], None)

    out = model.get_advantages_or_q_values(model_out)
    if config["num_atoms"] > 1:
        (advantages_or_q_values, z, support_logits_per_action, logits,
         dist) = out
    else:
        (advantages_or_q_values, logits, dist) = out

    if config["dueling"]:
        state_value = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            # Distributional dueling: center the per-atom advantage logits,
            # add the state value, then reduce to expected Q-values over atoms.
            support_logits_per_action_mean = torch.mean(
                support_logits_per_action, dim=1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                torch.unsqueeze(support_logits_per_action_mean, dim=1))
            support_logits_per_action = torch.unsqueeze(
                state_value, dim=1) + support_logits_per_action_centered
            support_prob_per_action = F.softmax(
                support_logits_per_action, dim=-1)
            q_values = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            # Standard dueling: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
            advantages_mean = reduce_mean_ignore_inf(advantages_or_q_values, 1)
            advantages_centered = advantages_or_q_values - torch.unsqueeze(
                advantages_mean, 1)
            q_values = state_value + advantages_centered
    else:
        q_values = advantages_or_q_values

    return q_values, logits, dist
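# Standalone sketch (not part of the policy code above): the distributional
# branch reduces per-action atom logits to expected Q-values. The shapes and
# values below are made up for illustration (batch of 1, 2 actions, 3 atoms);
# `z` plays the role of the support atoms returned by the model.
import torch
import torch.nn.functional as F


def _distributional_q_example():
    z = torch.tensor([-1.0, 0.0, 1.0])              # support atoms, [num_atoms]
    support_logits_per_action = torch.tensor(
        [[[0.2, 0.5, 0.3],
          [1.0, 0.0, -1.0]]])                       # [batch, actions, atoms]
    # Softmax over the atom axis yields a probability distribution per action;
    # the probability-weighted sum of atom values is the expected Q-value.
    support_prob_per_action = F.softmax(support_logits_per_action, dim=-1)
    q_values = torch.sum(z * support_prob_per_action, dim=-1)
    return q_values                                 # shape [1, 2]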
def compute_q_values(policy: Policy,
                     model: ModelV2,
                     obs: TensorType,
                     explore,
                     is_training: bool = False):
    config = policy.config

    model_out, state = model({
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }, [], None)

    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         probs_or_logits) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         probs_or_logits) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            # Distributional dueling: center the per-atom advantage logits,
            # add the state score, then reduce to expected Q-values over atoms.
            support_logits_per_action_mean = torch.mean(
                support_logits_per_action, dim=1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                torch.unsqueeze(support_logits_per_action_mean, dim=1))
            support_logits_per_action = torch.unsqueeze(
                state_score, dim=1) + support_logits_per_action_centered
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1)
            value = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs_or_logits = support_prob_per_action
        else:
            # Standard dueling: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
            advantages_mean = reduce_mean_ignore_inf(action_scores, 1)
            advantages_centered = action_scores - torch.unsqueeze(
                advantages_mean, 1)
            value = state_score + advantages_centered
    else:
        value = action_scores

    return value, logits, probs_or_logits
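# Standalone sketch (not part of the policy code above): the dueling
# aggregation used in the non-distributional branch,
# Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)). Values and shapes are made up;
# torch.mean stands in for reduce_mean_ignore_inf, which additionally skips
# -inf entries coming from masked-out actions.
import torch


def _dueling_q_example():
    action_scores = torch.tensor([[2.0, 0.0, -1.0]])   # A(s, a), [batch, actions]
    state_score = torch.tensor([[5.0]])                # V(s), [batch, 1]
    advantages_mean = torch.mean(action_scores, dim=1, keepdim=True)
    advantages_centered = action_scores - advantages_mean
    q_values = state_score + advantages_centered
    return q_values                                    # [[6.6667, 4.6667, 3.6667]]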
def compute_q_values(policy, model, obs, explore, is_training=False):
    if policy.config["num_atoms"] > 1:
        raise ValueError("torch DQN does not support distributional DQN yet!")

    model_out, state = model({
        SampleBatch.CUR_OBS: obs,
        "is_training": is_training,
    }, [], None)

    advantages_or_q_values = model.get_advantages_or_q_values(model_out)

    if policy.config["dueling"]:
        state_value = model.get_state_value(model_out)
        advantages_mean = reduce_mean_ignore_inf(advantages_or_q_values, 1)
        advantages_centered = advantages_or_q_values - torch.unsqueeze(
            advantages_mean, 1)
        q_values = state_value + advantages_centered
    else:
        q_values = advantages_or_q_values

    return q_values
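# All three variants call reduce_mean_ignore_inf, which is not defined in this
# file. The sketch below shows what such a helper typically looks like; it is
# modeled on the RLlib torch utility, so treat the exact implementation as an
# assumption: a mean over one axis that ignores -inf entries (e.g. the scores
# of masked-out, invalid actions).
import torch


def reduce_mean_ignore_inf(x, axis):
    """Mean over `axis`, counting only entries that are not -inf."""
    mask = torch.ne(x, float("-inf"))
    x_zeroed = torch.where(mask, x, torch.zeros_like(x))
    return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis)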