Example #1
  def recurrent_inference(self, hidden_state, action, training=True):
    """Recurrent (dynamics) inference: unrolls one step from a hidden state and action."""
    if self._embed_actions:
      one_hot_action = tf.one_hot(
          action, self._parametric_action_distribution.param_size)
      embedded_action = self._action_embeddings(one_hot_action)
    else:
      one_hot_action = tf.one_hot(
          action, self._parametric_action_distribution.param_size, 1., -1.)
      embedded_action = one_hot_action
    hidden_state = self._maybe_normalize_hidden_state(hidden_state)

    rnn_state = self._flat_to_rnn(hidden_state)
    rnn_output, next_rnn_state = self._core(embedded_action, rnn_state)
    next_hidden_state = self._rnn_to_flat(next_rnn_state)

    value_logits = self._value_head(next_hidden_state, training=training)
    value = self.value_encoder.decode(value_logits)

    reward_logits = self._reward_head(rnn_output, training=training)
    reward = self.reward_encoder.decode(reward_logits)

    policy_logits = self._policy_head(next_hidden_state, training=training)

    output = mzcore.NetworkOutput(
        value=value,
        value_logits=value_logits,
        reward=reward,
        reward_logits=reward_logits,
        policy_logits=policy_logits,
        hidden_state=next_hidden_state)
    return output
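
A hypothetical usage sketch for the method above: it assumes a `network` object exposing `recurrent_inference` and a `hidden_state` obtained from a prior `initial_inference` call; the action index and batch size are illustrative, not taken from the original code.

import tensorflow as tf

# Hypothetical one-step unroll of the learned dynamics for a chosen action.
actions = tf.constant([2], dtype=tf.int32)          # batch of one action index
output = network.recurrent_inference(hidden_state,  # from the previous inference step
                                     actions,
                                     training=False)
next_hidden_state = output.hidden_state             # reused at the next expansion
predicted_reward = output.reward
predicted_value = output.value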
Example #2
def send_initial_inference_request(
    predict_service: prediction_service_pb2_grpc.PredictionServiceStub,
    inputs: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
) -> core.NetworkOutput:
  """Initial inference for the agent, used at the beginning of MCTS."""
  input_ids, input_type_ids, input_features, action_history = inputs

  request = predict_pb2.PredictRequest()
  request.model_spec.name = FLAGS.initial_inference_model_name
  request.model_spec.signature_name = 'initial_inference'

  request.inputs['input_ids'].CopyFrom(
      tf.make_tensor_proto(values=np.expand_dims(input_ids, axis=0)))
  request.inputs['segment_ids'].CopyFrom(
      tf.make_tensor_proto(values=np.expand_dims(input_type_ids, axis=0)))
  request.inputs['features'].CopyFrom(
      tf.make_tensor_proto(values=np.expand_dims(input_features, axis=0)))
  request.inputs['action_history'].CopyFrom(
      tf.make_tensor_proto(values=np.expand_dims(action_history, axis=0)))
  response = predict_service.Predict(request)

  # Parse and `unbatch` the response.
  map_names = {
      f'output_{i}': v for (i, v) in enumerate([
          'value', 'value_logits', 'reward', 'reward_logits', 'policy_logits',
          'hidden_state'
      ])
  }
  outputs = {
      map_names[k]: tf.make_ndarray(v).squeeze()
      for k, v in response.outputs.items()
  }

  return core.NetworkOutput(**outputs)
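
A hypothetical way to call the helper above: it assumes a TensorFlow Serving instance reachable on localhost:8500 and placeholder input shapes; the model name is still taken from `FLAGS.initial_inference_model_name` inside the helper.

import grpc
import numpy as np
from tensorflow_serving.apis import prediction_service_pb2_grpc

# Hypothetical setup: open a gRPC channel to a TensorFlow Serving endpoint.
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Placeholder inputs; the real shapes depend on the tokenizer and environment.
inputs = (np.zeros(128, dtype=np.int32),          # input_ids
          np.zeros(128, dtype=np.int32),          # input_type_ids
          np.zeros((128, 4), dtype=np.float32),   # input_features
          np.zeros(8, dtype=np.int32))            # action_history
root_output = send_initial_inference_request(stub, inputs)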
Example #3
def send_recurrent_inference_request(
    hidden_state: np.ndarray, action: np.ndarray,
    predict_service: prediction_service_pb2_grpc.PredictionServiceStub
) -> core.NetworkOutput:
  """Recurrent inference for the agent, used during MCTS."""
  request = predict_pb2.PredictRequest()
  request.model_spec.name = FLAGS.recurrent_inference_model_name
  request.model_spec.signature_name = 'recurrent_inference'

  request.inputs['hidden_state'].CopyFrom(
      tf.make_tensor_proto(values=tf.expand_dims(hidden_state, axis=0)))
  request.inputs['action'].CopyFrom(
      tf.make_tensor_proto(
          values=np.expand_dims(action, axis=0).astype(np.int32)))
  response = predict_service.Predict(request)

  # Parse and `unbatch` the response.
  map_names = {
      f'output_{i}': v for (i, v) in enumerate([
          'value', 'value_logits', 'reward', 'reward_logits', 'policy_logits',
          'hidden_state'
      ])
  }
  outputs = {
      map_names[k]: tf.make_ndarray(v).squeeze()
      for k, v in response.outputs.items()
  }

  return core.NetworkOutput(**outputs)
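
Continuing the hypothetical setup from the previous sketch, the recurrent helper can be chained off the root output to unroll the model remotely for a fixed action sequence; the action indices below are placeholders.

import numpy as np

hidden_state = root_output.hidden_state
for action in [0, 3, 1]:  # placeholder action sequence
  step = send_recurrent_inference_request(hidden_state, np.asarray(action), stub)
  hidden_state = step.hidden_state  # feed the predicted state back in
  print('reward:', step.reward, 'value:', step.value)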
Example #4
    def initial_inference(self, observation, training=True):
        """Initial (representation) inference from a raw observation, used at the root of MCTS."""
        encoded_observation = self._encode_observation(observation,
                                                       training=training)
        hidden_state = self._to_hidden(encoded_observation, training=training)

        value_logits = self._value_head(hidden_state, training=training)
        value = self.value_encoder.decode(tf.nn.softmax(value_logits))

        # Rewards are only calculated in recurrent_inference.
        reward = tf.zeros_like(value)
        reward_logits = self.reward_encoder.encode(reward)

        policy_logits = self._policy_head(hidden_state, training=training)

        outputs = mzcore.NetworkOutput(value_logits=value_logits,
                                       value=value,
                                       reward_logits=reward_logits,
                                       reward=reward,
                                       policy_logits=policy_logits,
                                       hidden_state=hidden_state)
        return outputs
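
All four examples rely on a `value_encoder`/`reward_encoder` object with `encode` and `decode` methods. Below is a minimal sketch of that interface, assuming the common MuZero-style categorical (two-hot) scalar transform over an integer support; the actual encoder in the source repository may scale or transform values differently.

import tensorflow as tf

class CategoricalEncoder:
  """Sketch: maps scalars to/from a categorical distribution over an integer support."""

  def __init__(self, min_value, max_value):
    self.support = tf.range(min_value, max_value + 1, dtype=tf.float32)

  def encode(self, scalar):
    # Two-hot encoding: split the probability mass between the two nearest bins.
    scalar = tf.clip_by_value(scalar, self.support[0], self.support[-1])
    floor = tf.floor(scalar)
    upper_weight = scalar - floor
    lower_idx = tf.cast(floor - self.support[0], tf.int32)
    depth = tf.size(self.support)
    lower = tf.one_hot(lower_idx, depth) * (1. - upper_weight)[..., None]
    upper = tf.one_hot(lower_idx + 1, depth) * upper_weight[..., None]
    return lower + upper

  def decode(self, probs):
    # Expected value of the categorical distribution.
    return tf.reduce_sum(probs * self.support, axis=-1)

Under this interpretation, `initial_inference` in Example #4 encodes its zero reward into the same categorical space it later decodes from, which keeps the output shapes consistent between initial and recurrent inference steps.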