    def __init__(self,
                 sample_spec: tf.TensorSpec,
                 num_actions: int,
                 name: str = 'OneHotCategoricalProjectionNetwork') -> None:
        """
        Initialises a head of a multi-headed action network which will return a distribution over
        actions related to this head.

        :param sample_spec: A specification for the results of sampling from this head of the
            policy.
        :param num_actions: The dimensionality of this head's action space.
        :param name: A name for this head.
        """
        output_shape = (num_actions, )
        output_spec = self._output_distribution_spec(output_shape, sample_spec,
                                                     name)
        super(OneHotCategoricalProjectionNetwork, self).__init__(
            # We don't need these, but base class requires them.
            input_tensor_spec=None,
            state_spec=(),
            output_spec=output_spec,
            name=name)
        self._projection_layer = tf.keras.layers.Dense(num_actions,
                                                       activation=None)

        if not tensor_spec.is_bounded(sample_spec):
            raise ValueError('sample_spec must be bounded. Got: %s.' %
                             type(sample_spec))

        self._sample_spec = sample_spec
        self._output_shape = tf.TensorShape(output_shape)
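A minimal construction sketch for this head. The bounded, discrete sample_spec below is an assumption about how the surrounding multi-headed action network describes this head's samples; it is not fixed by the snippet above.

import tensorflow as tf
from tf_agents.specs import tensor_spec

# Hypothetical spec: this head picks one of 4 discrete sub-actions.
sample_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=3)

head = OneHotCategoricalProjectionNetwork(sample_spec, num_actions=4)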
    def __init__(self,
                 time_step_spec=None,
                 action_spec=None,
                 reward_network=None,
                 observation_and_action_constraint_splitter=None,
                 name=None):
        """Builds a GreedyRewardPredictionPolicy given a reward tf_agents.Network.

    This policy takes a tf_agents.Network predicting rewards and generates the
    action corresponding to the largest predicted reward.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      reward_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask.  The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of type int32 and shape (). '
                'Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1
        self._action_offset = action_spec.minimum
        self._reward_network = reward_network
        super(GreedyRewardPredictionPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             policy_state_spec=reward_network.state_spec,
                             clip=False,
                             name=name)
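The observation_and_action_constraint_splitter contract described above can be illustrated with a small sketch. The dict keys 'observation' and 'mask' are hypothetical; they only need to match how your environment packs the mask into the observation.

# The observation is assumed to be a dict holding the network input and a
# 0/1 mask of shape [batch_size, num_actions]. Because this only indexes
# into the dict, it also works when given a nest of TensorSpecs.
def observation_and_action_constraint_splitter(obs):
  return obs['observation'], obs['mask']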
Example #3
  def __init__(self, env):
    super(ActionOffsetWrapper, self).__init__(env)
    if tf.nest.is_nested(self._env.action_spec()):
      raise ValueError('ActionOffsetWrapper only works with single-array '
                       'action specs (not nested specs).')
    if not tensor_spec.is_bounded(self._env.action_spec()):
      raise ValueError('ActionOffsetWrapper only works with bounded '
                       'action specs.')
    if not tensor_spec.is_discrete(self._env.action_spec()):
      raise ValueError('ActionOffsetWrapper only works with discrete '
                       'action specs.')
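The snippet above only shows the validation. Below is a sketch of the rest of the wrapper, under the assumption that its job is to shift a bounded, discrete action spec so that it starts at zero; the method names follow common PyEnvironment wrapper conventions and are assumptions here, not the library's verified code.

  # assumes: from tf_agents.specs import array_spec
  def action_spec(self):
    # Expose an equivalent spec whose minimum is zero.
    spec = self._env.action_spec()
    return array_spec.BoundedArraySpec(
        spec.shape, spec.dtype, minimum=0,
        maximum=spec.maximum - spec.minimum, name=spec.name)

  def _step(self, action):
    # Shift the zero-based action back into the wrapped env's range.
    return self._env.step(action + self._env.action_spec().minimum)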
    def __init__(self,
                 sample_spec,
                 logits_init_output_factor=0.1,
                 name='CategoricalProjectionNetwork'):
        """Creates an instance of CategoricalProjectionNetwork.

    Args:
      sample_spec: A spec (either BoundedArraySpec or BoundedTensorSpec)
        detailing the shape and dtypes of samples pulled from the output
        distribution.
      logits_init_output_factor: Output factor for initializing kernel logits
        weights.
      name: A string representing name of the network.
    """
        unique_num_actions = np.unique(sample_spec.maximum -
                                       sample_spec.minimum + 1)
        if len(unique_num_actions) > 1 or np.any(unique_num_actions <= 0):
            raise ValueError(
                'Bounds on discrete actions must be the same for all '
                'dimensions and have at least 1 action.')

        output_shape = sample_spec.shape.concatenate([int(unique_num_actions)])
        output_spec = self._output_distribution_spec(output_shape, sample_spec)

        super(CategoricalProjectionNetwork, self).__init__(
            # We don't need these, but base class requires them.
            input_tensor_spec=None,
            state_spec=(),
            output_spec=output_spec,
            name=name)

        if not tensor_spec.is_bounded(sample_spec):
            raise ValueError('sample_spec must be bounded. Got: %s.' %
                             type(sample_spec))

        if not tensor_spec.is_discrete(sample_spec):
            raise ValueError('sample_spec must be discrete. Got: %s.' %
                             sample_spec)

        if len(unique_num_actions) > 1:
            raise ValueError(
                'Projection Network requires num_actions to be equal '
                'across action dimensions. Implement a more general categorical '
                'projection if you need more flexibility.')

        self._sample_spec = sample_spec
        self._output_shape = output_shape

        self._projection_layer = tf.keras.layers.Dense(
            self._output_shape.num_elements(),
            kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=logits_init_output_factor),
            bias_initializer=tf.keras.initializers.Zeros(),
            name='logits')
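A small construction sketch for the projection network above, assuming a two-dimensional discrete action where every dimension has the same number of values.

import tensorflow as tf
from tf_agents.specs import tensor_spec

# Two discrete sub-actions, each taking values in {0, 1, 2}.
sample_spec = tensor_spec.BoundedTensorSpec(
    shape=(2,), dtype=tf.int32, minimum=0, maximum=2)

projection_net = CategoricalProjectionNetwork(sample_spec)
# The dense 'logits' layer maps encoder features to 2 * 3 = 6 values, which
# the network's call reshapes to [..., 2, 3] and reads as independent
# categorical logits.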
Example #5
def factored_categorical(inputs,
                         output_spec,
                         outer_rank=1,
                         projection_layer=default_fully_connected):
    """Project a batch of inputs to a categorical distribution.

  Given an output spec for a single tensor discrete action, produces a
  neural net layer converting inputs to a categorical distribution
  matching the spec. The logits are derived from a fully connected linear
  layer. Each discrete action (each element of the output tensor) is sampled
  independently.

  Args:
    inputs: An input Tensor of shape [batch_size, ?].
    output_spec: An output spec (either BoundedArraySpec or BoundedTensorSpec).
    outer_rank: The number of outer dimensions of inputs to consider batch
      dimensions and to treat as batch dimensions of output distribution.
    projection_layer: Function taking in inputs, num_elements, scope and
      returning a projection of inputs to a Tensor of width num_elements.

  Returns:
    A tfp.distributions.Categorical object.

  Raises:
    ValueError: If output_spec contains multiple distinct ranges or is otherwise
      invalid.
  """
    if not tensor_spec.is_bounded(output_spec):
        raise ValueError('Input output_spec is of invalid type '
                         '%s.' % type(output_spec))
    if not tensor_spec.is_discrete(output_spec):
        raise ValueError('Output is not discrete.')
    num_outputs = np.unique(output_spec.maximum - output_spec.minimum + 1)
    num_ranges = len(num_outputs)
    if num_ranges > 1 or np.any(num_outputs <= 0):
        raise ValueError('Single discrete output has invalid ranges: '
                         '%s' % num_outputs)
    output_shape = output_spec.shape.concatenate([int(num_outputs)])
    batch_squash = utils.BatchSquash(outer_rank)
    inputs = batch_squash.flatten(inputs)
    logits = projection_layer(inputs,
                              output_shape.num_elements(),
                              scope='logits')
    logits = tf.reshape(logits, [-1] + output_shape.as_list())
    logits = batch_squash.unflatten(logits)
    return tfp.distributions.Categorical(logits, dtype=output_spec.dtype)
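A usage sketch for factored_categorical. The default_fully_connected helper is defined elsewhere in the original module, so the scope-based linear projection below is a stand-in, and because the helper expects a scope argument it is best built inside a tf.compat.v1 graph.

import tensorflow as tf
from tf_agents.specs import tensor_spec

def linear_projection(inputs, num_elements, scope):
  # Stand-in for default_fully_connected: a plain linear layer under `scope`.
  with tf.compat.v1.variable_scope(scope):
    return tf.compat.v1.layers.dense(inputs, num_elements, activation=None)

# Two discrete outputs, each with 5 possible values.
output_spec = tensor_spec.BoundedTensorSpec(
    shape=(2,), dtype=tf.int32, minimum=0, maximum=4)

# Inside a tf.compat.v1 graph:
# features = tf.compat.v1.placeholder(tf.float32, [None, 16])
# dist = factored_categorical(features, output_spec,
#                             projection_layer=linear_projection)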
Example #6
    def __init__(self,
                 time_step_spec=None,
                 action_spec=None,
                 reward_network=None,
                 name=None):
        """Builds a GreedyRewardPredictionPolicy given a reward tf_agents.Network.

    This policy takes a tf_agents.Network predicting rewards and generates the
    action corresponding to the largest predicted reward.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      reward_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of type int32 and shape (). '
                'Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1
        self._action_offset = action_spec.minimum
        self._reward_network = reward_network
        super(GreedyRewardPredictionPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             policy_state_spec=reward_network.state_spec,
                             clip=False,
                             name=name)
    def __init__(self,
                 time_step_spec=None,
                 action_spec=None,
                 reward_network=None,
                 observation_and_action_constraint_splitter=None,
                 emit_policy_info=(),
                 name=None):
        """Builds a GreedyRewardPredictionPolicy given a reward tf_agents.Network.

    This policy takes a tf_agents.Network predicting rewards and generates the
    action corresponding to the largest predicted reward.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      reward_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask.  The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of type int32 and shape (). '
                'Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1
        self._action_offset = action_spec.minimum
        reward_network.create_variables()
        self._reward_network = reward_network

        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._expected_num_actions])
        bandit_policy_type = ()
        if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in emit_policy_info:
            bandit_policy_type = (
                policy_utilities.create_bandit_policy_type_tensor_spec(
                    shape=[1]))
        info_spec = policy_utilities.PolicyInfo(
            predicted_rewards_mean=predicted_rewards_mean,
            bandit_policy_type=bandit_policy_type)

        super(GreedyRewardPredictionPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             policy_state_spec=reward_network.state_spec,
                             clip=False,
                             info_spec=info_spec,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
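A construction sketch for this variant, requesting the predicted mean rewards in the policy info. The observation spec, network size, and module paths (which follow the usual TF-Agents layout) are assumptions.

import tensorflow as tf
from tf_agents.bandits.policies import policy_utilities
from tf_agents.networks import q_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec([4], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=2)
# A QNetwork outputs one value per action, so it can serve as the reward net.
reward_net = q_network.QNetwork(observation_spec, action_spec,
                                fc_layer_params=(16,))

policy = GreedyRewardPredictionPolicy(
    ts.time_step_spec(observation_spec), action_spec, reward_net,
    emit_policy_info=(policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN,))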
Example #8
def _is_categorical_spec(spec):
    return (tensor_spec.is_discrete(spec) and tensor_spec.is_bounded(spec)
            and spec.shape == [] and spec.minimum == 0)
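For example, under this predicate a scalar, zero-based integer spec qualifies while a continuous one does not.

import tensorflow as tf
from tf_agents.specs import tensor_spec

digit_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=9)
unit_spec = tensor_spec.BoundedTensorSpec(
    (), tf.float32, minimum=0.0, maximum=1.0)

assert _is_categorical_spec(digit_spec)     # discrete, bounded, scalar, min == 0
assert not _is_categorical_spec(unit_spec)  # continuous, so rejected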
Example #9
    def __init__(self,
                 time_step_spec: types.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 reward_network: types.Network,
                 temperature: types.FloatOrReturningFloat = 1.0,
                 observation_and_action_constraint_splitter: Optional[
                     types.Splitter] = None,
                 accepts_per_arm_features: bool = False,
                 constraints: Tuple[constr.NeuralConstraint, ...] = (),
                 emit_policy_info: Tuple[Text, ...] = (),
                 name: Optional[Text] = None):
        """Builds a BoltzmannRewardPredictionPolicy given a reward network.

    This policy takes a tf_agents.Network predicting rewards and chooses an
    action with weighted probabilities (i.e., using a softmax over the network
    estimates of value for each action).

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      reward_network: An instance of a `tf_agents.network.Network`,
        callable via `network(observation, step_type) -> (output, final_state)`.
      temperature: float or callable that returns a float. The temperature used
        in the Boltzmann exploration.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask.  The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      constraints: iterable of constraints objects that are instances of
        `tf_agents.bandits.agents.NeuralConstraint`.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
        policy_utilities.check_no_mask_with_arm_features(
            accepts_per_arm_features,
            observation_and_action_constraint_splitter)
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        self._temperature = temperature
        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of type int32 and shape (). '
                'Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1
        self._action_offset = action_spec.minimum
        reward_network.create_variables()
        self._reward_network = reward_network
        self._constraints = constraints

        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._expected_num_actions])
        bandit_policy_type = ()
        if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in emit_policy_info:
            bandit_policy_type = (
                policy_utilities.create_bandit_policy_type_tensor_spec(
                    shape=[1]))
        if accepts_per_arm_features:
            # The features for the chosen arm are saved to policy_info.
            chosen_arm_features_info = (
                policy_utilities.create_chosen_arm_features_info_spec(
                    time_step_spec.observation))
            info_spec = policy_utilities.PerArmPolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean,
                bandit_policy_type=bandit_policy_type,
                chosen_arm_features=chosen_arm_features_info)
        else:
            info_spec = policy_utilities.PolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean,
                bandit_policy_type=bandit_policy_type)

        self._accepts_per_arm_features = accepts_per_arm_features

        super(BoltzmannRewardPredictionPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             policy_state_spec=reward_network.state_spec,
                             clip=False,
                             info_spec=info_spec,
                             emit_log_probability='log_probability'
                             in emit_policy_info,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
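The temperature argument may be a plain float or a zero-argument callable, which makes it easy to anneal exploration over training. The schedule below is purely illustrative.

import tensorflow as tf

train_step = tf.Variable(0, dtype=tf.int64)

def temperature_fn():
  # Anneal from roughly 10.0 towards 1.0 as train_step grows; returns a
  # scalar tensor, which is fine wherever the value feeds a softmax.
  return 1.0 + 9.0 * tf.exp(-tf.cast(train_step, tf.float32) / 1000.0)

# policy = BoltzmannRewardPredictionPolicy(
#     time_step_spec, action_spec, reward_network, temperature=temperature_fn)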
    def __init__(
            self,
            time_step_spec: Optional[ts.TimeStep],
            action_spec: Optional[NestedBoundedTensorSpec],
            scalarizer: multi_objective_scalarizer.Scalarizer,
            objective_networks: Sequence[Network],
            observation_and_action_constraint_splitter: types.Splitter = None,
            accepts_per_arm_features: bool = False,
            emit_policy_info: Tuple[Text, ...] = (),
            name: Optional[Text] = None):
        """Builds a GreedyMultiObjectiveNeuralPolicy based on multiple networks.

    This policy takes an iterable of `tf_agents.Network`, each responsible for
    predicting a specific objective, along with a `Scalarizer` object to
    generate an action by maximizing the scalarized objective, i.e., the output
    of the `Scalarizer` applied to the multiple predicted objectives by the
    networks.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of `BoundedTensorSpec` representing the actions.
      scalarizer: A
       `tf_agents.bandits.multi_objective.multi_objective_scalarizer.Scalarizer`
        object that implements scalarization of multiple objectives into a
        single scalar reward.
      objective_networks: A Sequence of `tf_agents.network.Network` objects to
        be used by the policy. Each network will be called with
        call(observation, step_type) and is expected to provide a prediction for
        a specific objective for all actions.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask.  The mask should be a 0-1 `Tensor` of shape `[batch_size,
        num_actions]`. This function should also work with a `TensorSpec` as
        input, and should output `TensorSpec` objects for the observation and
        mask.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
      NotImplementedError: If `action_spec` is not a `BoundedTensorSpec` of type
        int32 and shape ().
      ValueError: If `objective_networks` has fewer than two networks.
      ValueError: If `accepts_per_arm_features` is true but `time_step_spec` is
        None.
    """
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of type int32 and shape (). '
                'Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1
        self._action_offset = action_spec.minimum
        policy_state_spec = []
        for network in objective_networks:
            policy_state_spec.append(network.state_spec)
            network.create_variables()
        self._objective_networks = objective_networks
        self._scalarizer = scalarizer
        self._num_objectives = len(self._objective_networks)
        if self._num_objectives < 2:
            raise ValueError(
                'Number of objectives should be at least two, but found to be {}'
                .format(self._num_objectives))

        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._num_objectives, self._expected_num_actions])
        bandit_policy_type = ()
        if policy_utilities.InfoFields.BANDIT_POLICY_TYPE in emit_policy_info:
            bandit_policy_type = (
                policy_utilities.create_bandit_policy_type_tensor_spec(
                    shape=[1]))
        if accepts_per_arm_features:
            if time_step_spec is None:
                raise ValueError(
                    'time_step_spec should not be None for per-arm-features policies, '
                    'but found to be None.')
            # The features for the chosen arm are saved to policy_info.
            chosen_arm_features_info = (
                policy_utilities.create_chosen_arm_features_info_spec(
                    time_step_spec.observation,
                    observation_and_action_constraint_splitter))
            info_spec = policy_utilities.PerArmPolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean,
                bandit_policy_type=bandit_policy_type,
                chosen_arm_features=chosen_arm_features_info)
        else:
            info_spec = policy_utilities.PolicyInfo(
                predicted_rewards_mean=predicted_rewards_mean,
                bandit_policy_type=bandit_policy_type)

        self._accepts_per_arm_features = accepts_per_arm_features

        super(GreedyMultiObjectiveNeuralPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             policy_state_spec=policy_state_spec,
                             clip=False,
                             info_spec=info_spec,
                             emit_log_probability='log_probability'
                             in emit_policy_info,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
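A construction sketch with two objective networks. A LinearScalarizer with per-objective weights is the assumed scalarizer here, so that line is left commented rather than asserted.

import tensorflow as tf
from tf_agents.networks import q_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tensor_spec.TensorSpec([4], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec((), tf.int32, minimum=0, maximum=2)

# One network per objective; each predicts a value for every action.
objective_networks = [
    q_network.QNetwork(observation_spec, action_spec, fc_layer_params=(16,))
    for _ in range(2)
]

# scalarizer = multi_objective_scalarizer.LinearScalarizer(weights=[0.5, 0.5])
# policy = GreedyMultiObjectiveNeuralPolicy(
#     ts.time_step_spec(observation_spec), action_spec, scalarizer,
#     objective_networks)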
Example #11
def normal(inputs,
           output_spec,
           outer_rank=1,
           projection_layer=default_fully_connected,
           mean_transform=tanh_squash_to_spec,
           std_initializer=tf.zeros_initializer(),
           std_transform=tf.exp,
           distribution_cls=tfp.distributions.Normal):
    """Project a batch of inputs to a batch of means and standard deviations.

  Given an output spec for a single tensor continuous action, produces a
  neural net layer converting inputs to a normal distribution matching
  the spec.  The mean is derived from a fully connected linear layer as
  mean_transform(layer_output, output_spec).  The std is fixed to a single
  trainable tensor (thus independent of the inputs).  Specifically, std is
  parameterized as std_transform(variable).

  Args:
    inputs: An input Tensor of shape [batch_size, ?].
    output_spec: An output spec (either BoundedArraySpec or BoundedTensorSpec).
    outer_rank: The number of outer dimensions of inputs to consider batch
      dimensions and to treat as batch dimensions of output distribution.
    projection_layer: Function taking in inputs, num_elements, scope and
      returning a projection of inputs to a Tensor of width num_elements.
    mean_transform: A function taking in layer output and the output_spec,
      returning the means.  Defaults to tanh_squash_to_spec.
    std_initializer: Initializer for std_dev variables.
    std_transform: The function applied to the trainable std variable. For
      example, tf.exp (default), tf.nn.softplus.
    distribution_cls: The distribution class to use for output distribution.
      Default is tfp.distributions.Normal.

  Returns:
    A tfp.distributions.Normal object in which the standard deviation is not
      dependent on input.

  Raises:
    ValueError: If output_spec is invalid.
  """
    if not tensor_spec.is_bounded(output_spec):
        raise ValueError('Input output_spec is of invalid type '
                         '%s.' % type(output_spec))
    if not tensor_spec.is_continuous(output_spec):
        raise ValueError('Output is not continuous.')

    batch_squash = utils.BatchSquash(outer_rank)
    inputs = batch_squash.flatten(inputs)
    means = projection_layer(inputs,
                             output_spec.shape.num_elements(),
                             scope='means')
    stds = tf.contrib.layers.bias_add(
        tf.zeros_like(means),  # Independent of inputs.
        initializer=std_initializer,
        scope='stds',
        activation_fn=None)

    means = tf.reshape(means, [-1] + output_spec.shape.as_list())
    means = mean_transform(means, output_spec)
    means = tf.cast(means, output_spec.dtype)

    stds = tf.reshape(stds, [-1] + output_spec.shape.as_list())
    stds = std_transform(stds)
    stds = tf.cast(stds, output_spec.dtype)

    means, stds = batch_squash.unflatten(means), batch_squash.unflatten(stds)
    return distribution_cls(means, stds)
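This helper predates TF 2.x (the tf.contrib.layers.bias_add call only exists in TF 1.x installs), so it has to run as TF1 graph code. A sketch of the inputs, reusing a scope-based linear projection like the one shown after factored_categorical above; the spec is an illustrative assumption.

import tensorflow.compat.v1 as tf1
from tf_agents.specs import tensor_spec

# A 3-dimensional continuous action bounded in [-1, 1].
output_spec = tensor_spec.BoundedTensorSpec(
    shape=(3,), dtype=tf1.float32, minimum=-1.0, maximum=1.0)

# Inside a TF1 graph (tf.contrib requires an actual TF 1.x install):
# inputs = tf1.placeholder(tf1.float32, [None, 32])
# dist = normal(inputs, output_spec, projection_layer=linear_projection)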
    def __init__(self,
                 time_step_spec: types.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 alpha: Sequence[tf.Variable],
                 beta: Sequence[tf.Variable],
                 observation_and_action_constraint_splitter: Optional[
                     types.Splitter] = None,
                 emit_policy_info: Sequence[Text] = (),
                 name: Optional[Text] = None):
        """Builds a BernoulliThompsonSamplingPolicy.

    For a reference, see e.g., Chapter 3 in "A Tutorial on Thompson Sampling" by
    Russo et al. (https://web.stanford.edu/~bvr/pubs/TS_Tutorial.pdf).

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      alpha: A list or tuple of `tf.Variable`s holding the `alpha` parameter of
        the beta distribution of each arm.
      beta: A list or tuple of `tf.Variable`s holding the `beta` parameter of
        the beta distribution of each arm.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the network and 2) the
        mask.  The mask should be a 0-1 `Tensor` of shape
        `[batch_size, num_actions]`. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      name: The name of this policy. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      NotImplementedError: If `action_spec` contains more than one
        `BoundedTensorSpec` or the `BoundedTensorSpec` is not valid.
    """
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise NotImplementedError(
                'action_spec can only contain a single BoundedTensorSpec.')

        action_spec = flat_action_spec[0]
        if (not tensor_spec.is_bounded(action_spec)
                or not tensor_spec.is_discrete(action_spec)
                or action_spec.shape.rank > 1
                or action_spec.shape.num_elements() != 1):
            raise NotImplementedError(
                'action_spec must be a BoundedTensorSpec of integer type and '
                'shape (). Found {}.'.format(action_spec))
        self._expected_num_actions = action_spec.maximum - action_spec.minimum + 1

        if len(alpha) != self._expected_num_actions:
            raise ValueError(
                'The size of alpha parameters is expected to be equal '
                'to the number of actions, but found to be {}'.format(
                    len(alpha)))
        self._alpha = alpha
        if len(alpha) != len(beta):
            raise ValueError(
                'The size of alpha parameters is expected to be equal '
                'to the size of beta parameters')
        self._beta = beta

        self._emit_policy_info = emit_policy_info
        predicted_rewards_mean = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN in emit_policy_info:
            predicted_rewards_mean = tensor_spec.TensorSpec(
                [self._expected_num_actions])
        predicted_rewards_sampled = ()
        if policy_utilities.InfoFields.PREDICTED_REWARDS_SAMPLED in (
                emit_policy_info):
            predicted_rewards_sampled = tensor_spec.TensorSpec(
                [self._expected_num_actions])
        info_spec = policy_utilities.PolicyInfo(
            predicted_rewards_mean=predicted_rewards_mean,
            predicted_rewards_sampled=predicted_rewards_sampled)

        super(BernoulliThompsonSamplingPolicy,
              self).__init__(time_step_spec,
                             action_spec,
                             info_spec=info_spec,
                             emit_log_probability='log_probability'
                             in emit_policy_info,
                             observation_and_action_constraint_splitter=(
                                 observation_and_action_constraint_splitter),
                             name=name)
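A construction sketch for a three-armed Bernoulli bandit, with each arm's Beta posterior initialised to the uniform Beta(1, 1) prior. The observation spec is an arbitrary assumption.

import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

num_actions = 3
observation_spec = tensor_spec.TensorSpec([1], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec(
    (), tf.int32, minimum=0, maximum=num_actions - 1)

# One Beta(alpha, beta) posterior per arm.
alpha = [tf.Variable(1.0) for _ in range(num_actions)]
beta = [tf.Variable(1.0) for _ in range(num_actions)]

policy = BernoulliThompsonSamplingPolicy(
    ts.time_step_spec(observation_spec), action_spec, alpha, beta)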