Beispiel #1
0
def compute_q_values(policy: Policy,
                     model: ModelV2,
                     obs: TensorType,
                     explore,
                     is_training=None) -> TensorType:
    """Run ``model`` forward on ``obs`` and return the first element of its output.

    The observations are wrapped in a ``SampleBatch`` (under the ``obs`` key)
    together with a training-mode flag, and the model is called with an empty
    list and ``None`` as its second and third arguments.

    Args:
        policy: Policy that supplies the default training-mode flag via
            ``_get_is_training_placeholder()`` when ``is_training`` is None.
        model: Model to evaluate.
        obs: Observations to feed to the model.
        explore: Accepted for interface compatibility; not read by this body.
        is_training: Explicit training-mode flag. When None, the policy's
            placeholder is used instead.

    Returns:
        The first element of the tuple returned by ``model(...)`` (presumably
        the Q-value outputs, given the function name — confirm against the
        model in use). The second element is discarded.
    """
    if is_training is None:
        # No explicit flag given: defer to the policy's placeholder.
        is_training = policy._get_is_training_placeholder()

    input_batch = SampleBatch(obs=obs, _is_training=is_training)
    # NOTE(review): `[]` / `None` look like "no recurrent state, no seq_lens",
    # i.e. this assumes a non-recurrent model — confirm.
    q_out, _unused_state = model(input_batch, [], None)
    return q_out
Beispiel #2
0
def compute_q_values(policy: Policy,
                     model: ModelV2,
                     obs: TensorType,
                     explore,
                     is_training=None) -> TensorType:
    """Evaluate ``model`` on ``obs`` and return the first element of its output.

    The model receives a plain dict holding the observations under
    ``SampleBatch.CUR_OBS`` and a training-mode flag under ``"is_training"``.
    It also receives an empty list and ``None`` as its second and third
    arguments.

    Args:
        policy: Policy that provides the fallback training-mode flag via
            ``_get_is_training_placeholder()``.
        model: Model to evaluate.
        obs: Observations to feed to the model.
        explore: Accepted for interface compatibility; not read by this body.
        is_training: Explicit training-mode flag. When None, the policy's
            placeholder is used.

    Returns:
        The first element of the tuple returned by ``model(...)`` (presumably
        the Q-value outputs, given the function name — confirm). The second
        element is discarded.
    """
    training_flag = is_training
    if training_flag is None:
        training_flag = policy._get_is_training_placeholder()

    input_dict = {
        SampleBatch.CUR_OBS: obs,
        "is_training": training_flag,
    }
    # NOTE(review): `[]` / `None` presumably mean "no RNN state, no seq_lens";
    # this looks like it assumes a non-recurrent model — confirm.
    q_out, _ = model(input_dict, [], None)
    return q_out