def __init__(
            self,
            time_step_spec: Optional[ts.TimeStep],
            action_spec: Optional[types.NestedBoundedTensorSpec],
            scalarizer: multi_objective_scalarizer.Scalarizer,
            objective_networks: Sequence[Network],
            optimizer: tf.keras.optimizers.Optimizer,
            observation_and_action_constraint_splitter: types.Splitter = None,
            accepts_per_arm_features: bool = False,
            # Params for training.
            error_loss_fn: Callable[
                ..., tf.Tensor] = tf.compat.v1.losses.mean_squared_error,
            gradient_clipping: Optional[float] = None,
            # Params for debugging.
            debug_summaries: bool = False,
            summarize_grads_and_vars: bool = False,
            enable_summaries: bool = True,
            emit_policy_info: Tuple[Text, ...] = (),
            train_step_counter: Optional[tf.Variable] = None,
            name: Optional[Text] = None):
        """Creates a Greedy Multi-objective Neural Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of `BoundedTensorSpec` representing the actions.
      scalarizer: A
       `tf_agents.bandits.multi_objective.multi_objective_scalarizer.Scalarizer`
        object that implements scalarization of multiple objectives into a
        single scalar reward.
      objective_networks: A Sequence of `tf_agents.network.Network` objects to
        be used by the agent. Each network will be called with
        call(observation, step_type) and is expected to provide a prediction for
        a specific objective for all actions.
      optimizer: A `tf.keras.optimizers.Optimizer` object, the optimizer to use
        for training.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit agent and
        policy, and 2) the boolean mask of shape `[batch_size, num_actions]`.
        This function should also work with a `TensorSpec` as input, and should
        output `TensorSpec` objects for the observation and mask.
      accepts_per_arm_features: (bool) Whether the agent accepts per-arm
        features.
      error_loss_fn: A function for computing the error loss, taking parameters
        labels, predictions, and weights (any function from tf.losses would
        work). The default is `tf.losses.mean_squared_error`.
      gradient_clipping: A float representing the norm length to clip gradients
        (or None for no clipping).
      debug_summaries: A Python bool, default False. When True, debug summaries
        are gathered.
      summarize_grads_and_vars: A Python bool, default False. When True,
        gradients and network variable summaries are written during training.
      enable_summaries: A Python bool, default True. When False, all summaries
        (debug or otherwise) should not be written.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      train_step_counter: An optional `tf.Variable` to increment every time the
        train op is run.  Defaults to the `global_step`.
      name: Python str name of this agent. All variables in this module will
        fall under that name. Defaults to the class name.

    Raises:
      ValueError:
        - If the action spec contains more than one action or it is not a
          bounded scalar int32 spec with minimum 0.
        - If `objective_networks` has fewer than two networks.
    """
        tf.Module.__init__(self, name=name)
        common.tf_agents_gauge.get_cell('TFABandit').set(True)
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
        self._num_actions = policy_utilities.get_num_actions_from_tensor_spec(
            action_spec)
        self._accepts_per_arm_features = accepts_per_arm_features

        self._num_objectives = len(objective_networks)
        if self._num_objectives < 2:
            raise ValueError(
                'Number of objectives should be at least two, but found to be {}'
                .format(self._num_objectives))
        self._objective_networks = objective_networks
        self._optimizer = optimizer
        self._error_loss_fn = error_loss_fn
        self._gradient_clipping = gradient_clipping
        self._heteroscedastic = [
            isinstance(network,
                       heteroscedastic_q_network.HeteroscedasticQNetwork)
            for network in objective_networks
        ]

        policy = greedy_multi_objective_policy.GreedyMultiObjectiveNeuralPolicy(
            time_step_spec,
            action_spec,
            scalarizer,
            self._objective_networks,
            observation_and_action_constraint_splitter,
            accepts_per_arm_features=accepts_per_arm_features,
            emit_policy_info=emit_policy_info)
        training_data_spec = None
        if accepts_per_arm_features:
            training_data_spec = bandit_spec_utils.drop_arm_observation(
                policy.trajectory_spec)

        super(GreedyMultiObjectiveNeuralAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy=policy,
                             train_sequence_length=None,
                             training_data_spec=training_data_spec,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             enable_summaries=enable_summaries,
                             train_step_counter=train_step_counter)
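# --- Usage sketch for the multi-objective agent above (not part of the
# original source). It assumes the public tf_agents import paths, a
# `LinearScalarizer` from the scalarizer module referenced above, and small
# illustrative QNetworks; sizes and the learning rate are arbitrary.
import tensorflow as tf
from tf_agents.bandits.agents import greedy_multi_objective_neural_agent
from tf_agents.bandits.multi_objective import multi_objective_scalarizer
from tf_agents.networks import q_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tf.TensorSpec(shape=[4], dtype=tf.float32)
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=2)  # 3 discrete actions.

# One network per objective; the constructor requires at least two.
objective_networks = [
    q_network.QNetwork(observation_spec, action_spec, fc_layer_params=(16,))
    for _ in range(2)
]
# Scalarizes the per-objective predictions into a single reward via a
# weighted sum before the policy acts greedily on it.
scalarizer = multi_objective_scalarizer.LinearScalarizer([0.7, 0.3])

agent = greedy_multi_objective_neural_agent.GreedyMultiObjectiveNeuralAgent(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    scalarizer=scalarizer,
    objective_networks=objective_networks,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))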
  def __init__(
      self,
      time_step_spec,
      action_spec,
      reward_network,
      optimizer,
      observation_and_action_constraint_splitter=None,
      accepts_per_arm_features=False,
      # Params for training.
      error_loss_fn=tf.compat.v1.losses.mean_squared_error,
      gradient_clipping=None,
      # Params for debugging.
      debug_summaries=False,
      summarize_grads_and_vars=False,
      enable_summaries=True,
      emit_policy_info=(),
      train_step_counter=None,
      laplacian_matrix=None,
      laplacian_smoothing_weight=0.001,
      name=None):
    """Creates a Greedy Reward Network Prediction Agent.

     In some use cases, the actions are not independent and they are related to
     each other (e.g., when the actions are ordinal integers). Assuming that
     the relations between arms can be modeled by a graph, we may want to
     enforce that the estimated reward function is smooth over the graph. This
     implies that the estimated rewards `r_i` and `r_j` for two related actions
     `i` and `j`, should be close to each other. To quantify this smoothness
     criterion we use the Laplacian matrix `L` of the graph over the actions.
     When Laplacian smoothing is enabled, the loss is extended to:
     ```
       Loss_new := Loss + lambda r^T * L * r,
     ```
     where `r` is the estimated reward vector for all actions. The second
     term is the Laplacian smoothing regularization term, and `lambda` is the
     weight that determines how strongly we enforce the regularization.
     For more details, please see:
     "Bandits on graphs and structures", Michal Valko
     https://hal.inria.fr/tel-01359757/document

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of `BoundedTensorSpec` representing the actions.
      reward_network: A `tf_agents.network.Network` to be used by the agent. The
        network will be called with call(observation, step_type) and it is
        expected to provide a reward prediction for all actions.
      optimizer: The optimizer to use for training.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit agent and
        policy, and 2) the boolean mask. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      error_loss_fn: A function for computing the error loss, taking parameters
        labels, predictions, and weights (any function from tf.losses would
        work). The default is `tf.losses.mean_squared_error`.
      gradient_clipping: A float representing the norm length to clip gradients
        (or None for no clipping).
      debug_summaries: A Python bool, default False. When True, debug summaries
        are gathered.
      summarize_grads_and_vars: A Python bool, default False. When True,
        gradients and network variable summaries are written during training.
      enable_summaries: A Python bool, default True. When False, all summaries
        (debug or otherwise) should not be written.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      train_step_counter: An optional `tf.Variable` to increment every time the
        train op is run.  Defaults to the `global_step`.
      laplacian_matrix: A float `Tensor` or a numpy array shaped
        `[num_actions, num_actions]`. This holds the Laplacian matrix used to
        regularize the smoothness of the estimated expected reward function.
        This only applies to problems where the actions have a graph structure.
        If `None`, the regularization is not applied.
      laplacian_smoothing_weight: A float that determines the weight of the
        regularization term. Note that this has no effect if `laplacian_matrix`
        above is `None`.
      name: Python str name of this agent. All variables in this module will
        fall under that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or it is
      not a bounded scalar int32 spec with minimum 0.
      InvalidArgumentError: if the Laplacian provided is not None and not valid.
    """
    tf.Module.__init__(self, name=name)
    common.tf_agents_gauge.get_cell('TFABandit').set(True)
    self._observation_and_action_constraint_splitter = (
        observation_and_action_constraint_splitter)
    self._num_actions = bandit_utils.get_num_actions_from_tensor_spec(
        action_spec)
    self._accepts_per_arm_features = accepts_per_arm_features

    reward_network.create_variables()
    self._reward_network = reward_network
    self._optimizer = optimizer
    self._error_loss_fn = error_loss_fn
    self._gradient_clipping = gradient_clipping
    self._heteroscedastic = isinstance(
        reward_network, heteroscedastic_q_network.HeteroscedasticQNetwork)
    self._laplacian_matrix = None
    if laplacian_matrix is not None:
      self._laplacian_matrix = tf.convert_to_tensor(
          laplacian_matrix, dtype=tf.float32)
      # Check the validity of the Laplacian matrix.
      tf.debugging.assert_near(
          0.0, tf.norm(tf.reduce_sum(self._laplacian_matrix, 1)))
      tf.debugging.assert_near(
          0.0, tf.norm(tf.reduce_sum(self._laplacian_matrix, 0)))
    self._laplacian_smoothing_weight = laplacian_smoothing_weight

    policy = greedy_reward_policy.GreedyRewardPredictionPolicy(
        time_step_spec,
        action_spec,
        reward_network,
        observation_and_action_constraint_splitter,
        accepts_per_arm_features=accepts_per_arm_features,
        emit_policy_info=emit_policy_info)
    training_data_spec = None
    if accepts_per_arm_features:
      training_data_spec = bandit_spec_utils.drop_arm_observation(
          policy.trajectory_spec)

    super(GreedyRewardPredictionAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy=policy,
        train_sequence_length=None,
        training_data_spec=training_data_spec,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        enable_summaries=enable_summaries,
        train_step_counter=train_step_counter)
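# --- Sketch (not part of the original source) of building a `laplacian_matrix`
# for the agent above when the actions are ordinal integers on a path graph,
# so that neighbouring actions receive similar reward estimates. A graph
# Laplacian L = D - A has zero row and column sums, which is exactly what the
# `assert_near` checks in the constructor verify.
import numpy as np


def path_graph_laplacian(num_actions):
    """Returns the Laplacian of a path graph over `num_actions` actions."""
    adjacency = np.zeros((num_actions, num_actions), dtype=np.float32)
    for i in range(num_actions - 1):
        adjacency[i, i + 1] = 1.0  # Action i is related to action i + 1.
        adjacency[i + 1, i] = 1.0
    degree = np.diag(adjacency.sum(axis=1))
    return degree - adjacency


laplacian = path_graph_laplacian(5)
assert np.allclose(laplacian.sum(axis=0), 0.0)  # Column sums are zero.
assert np.allclose(laplacian.sum(axis=1), 0.0)  # Row sums are zero.
# The result can be passed as `laplacian_matrix=laplacian`, together with a
# positive `laplacian_smoothing_weight`, when constructing the agent above.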
Example #3
    def __init__(self,
                 exploration_policy,
                 time_step_spec: types.TimeStep,
                 action_spec: types.BoundedTensorSpec,
                 variable_collection: Optional[
                     LinearBanditVariableCollection] = None,
                 alpha: float = 1.0,
                 gamma: float = 1.0,
                 use_eigendecomp: bool = False,
                 tikhonov_weight: float = 1.0,
                 add_bias: bool = False,
                 emit_policy_info: Sequence[Text] = (),
                 emit_log_probability: bool = False,
                 observation_and_action_constraint_splitter: Optional[
                     types.Splitter] = None,
                 accepts_per_arm_features: bool = False,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 enable_summaries: bool = True,
                 dtype: tf.DType = tf.float32,
                 name: Optional[Text] = None):
        """Initialize an instance of `LinearBanditAgent`.

    Args:
      exploration_policy: An Enum of type `ExplorationPolicy`. The kind of
        policy we use for exploration. Currently supported policies are
        `LinUCBPolicy` and `LinearThompsonSamplingPolicy`.
      time_step_spec: A `TimeStep` spec describing the expected `TimeStep`s.
      action_spec: A scalar `BoundedTensorSpec` with `int32` or `int64` dtype
        describing the number of actions for this agent.
      variable_collection: Instance of `LinearBanditVariableCollection`.
        Collection of variables to be updated by the agent. If `None`, a new
        instance of `LinearBanditVariableCollection` will be created.
      alpha: (float) positive scalar. This is the exploration parameter that
        multiplies the confidence intervals.
      gamma: a float forgetting factor in [0.0, 1.0]. When set to 1.0, the
        algorithm does not forget.
      use_eigendecomp: whether to use eigen-decomposition or not. The default
        solver is Conjugate Gradient.
      tikhonov_weight: (float) tikhonov regularization term.
      add_bias: If true, a bias term will be added to the linear reward
        estimation.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      emit_log_probability: Whether the policy emits log-probabilities or not.
        Since the policy is deterministic, the probability is just 1.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit agent and
        policy, and 2) the boolean mask. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      accepts_per_arm_features: (bool) Whether the agent accepts per-arm
        features.
      debug_summaries: A Python bool, default False. When True, debug summaries
        are gathered.
      summarize_grads_and_vars: A Python bool, default False. When True,
        gradients and network variable summaries are written during training.
      enable_summaries: A Python bool, default True. When False, all summaries
        (debug or otherwise) should not be written.
      dtype: The type of the parameters stored and updated by the agent. Should
        be one of `tf.float32` and `tf.float64`. Defaults to `tf.float32`.
      name: a name for this instance of `LinearBanditAgent`.

    Raises:
      ValueError if dtype is not one of `tf.float32` or `tf.float64`.
      TypeError if variable_collection is not an instance of
        `LinearBanditVariableCollection`.
    """
        tf.Module.__init__(self, name=name)
        common.tf_agents_gauge.get_cell('TFABandit').set(True)
        self._num_actions = policy_utilities.get_num_actions_from_tensor_spec(
            action_spec)
        self._num_models = 1 if accepts_per_arm_features else self._num_actions
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
        self._time_step_spec = time_step_spec
        self._accepts_per_arm_features = accepts_per_arm_features
        self._add_bias = add_bias
        if observation_and_action_constraint_splitter is not None:
            context_spec, _ = observation_and_action_constraint_splitter(
                time_step_spec.observation)
        else:
            context_spec = time_step_spec.observation

        (self._global_context_dim,
         self._arm_context_dim) = bandit_spec_utils.get_context_dims_from_spec(
             context_spec, accepts_per_arm_features)
        if self._add_bias:
            # The bias is added via a constant 1 feature.
            self._global_context_dim += 1
        self._overall_context_dim = self._global_context_dim + self._arm_context_dim

        self._alpha = alpha
        if variable_collection is None:
            variable_collection = LinearBanditVariableCollection(
                context_dim=self._overall_context_dim,
                num_models=self._num_models,
                use_eigendecomp=use_eigendecomp,
                dtype=dtype)
        elif not isinstance(variable_collection,
                            LinearBanditVariableCollection):
            raise TypeError('Parameter `variable_collection` should be '
                            'of type `LinearBanditVariableCollection`.')
        self._variable_collection = variable_collection
        self._cov_matrix_list = variable_collection.cov_matrix_list
        self._data_vector_list = variable_collection.data_vector_list
        self._eig_matrix_list = variable_collection.eig_matrix_list
        self._eig_vals_list = variable_collection.eig_vals_list
        # We keep track of the number of samples per arm.
        self._num_samples_list = variable_collection.num_samples_list
        self._gamma = gamma
        if self._gamma < 0.0 or self._gamma > 1.0:
            raise ValueError(
                'Forgetting factor `gamma` must be in [0.0, 1.0].')
        self._dtype = dtype
        if dtype not in (tf.float32, tf.float64):
            raise ValueError(
                'Agent dtype should be either `tf.float32` or `tf.float64`.')
        self._use_eigendecomp = use_eigendecomp
        self._tikhonov_weight = tikhonov_weight

        if exploration_policy == ExplorationPolicy.linear_ucb_policy:
            exploration_strategy = lin_policy.ExplorationStrategy.optimistic
        elif exploration_policy == (
                ExplorationPolicy.linear_thompson_sampling_policy):
            exploration_strategy = lin_policy.ExplorationStrategy.sampling
        else:
            raise ValueError(
                'Linear bandit agent with policy %s not implemented' %
                exploration_policy)
        policy = lin_policy.LinearBanditPolicy(
            action_spec=action_spec,
            cov_matrix=self._cov_matrix_list,
            data_vector=self._data_vector_list,
            num_samples=self._num_samples_list,
            time_step_spec=time_step_spec,
            exploration_strategy=exploration_strategy,
            alpha=alpha,
            eig_vals=self._eig_vals_list if self._use_eigendecomp else (),
            eig_matrix=self._eig_matrix_list if self._use_eigendecomp else (),
            tikhonov_weight=self._tikhonov_weight,
            add_bias=add_bias,
            emit_policy_info=emit_policy_info,
            emit_log_probability=emit_log_probability,
            accepts_per_arm_features=accepts_per_arm_features,
            observation_and_action_constraint_splitter=(
                observation_and_action_constraint_splitter))

        training_data_spec = None
        if accepts_per_arm_features:
            training_data_spec = bandit_spec_utils.drop_arm_observation(
                policy.trajectory_spec)
        super(LinearBanditAgent,
              self).__init__(time_step_spec=time_step_spec,
                             action_spec=action_spec,
                             policy=policy,
                             collect_policy=policy,
                             training_data_spec=training_data_spec,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             enable_summaries=enable_summaries,
                             train_sequence_length=None)
        self._as_trajectory = data_converter.AsTrajectory(self.data_context,
                                                          sequence_length=None)
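# --- Sketch (not part of the original source) of an
# `observation_and_action_constraint_splitter`, as described in the docstring
# above. It assumes the environment emits a dict observation with illustrative
# keys 'context' and 'action_mask'; because the function only does dict
# lookups, it works both on batched tensors and on a nest of `TensorSpec`s.
import tensorflow as tf
from tf_agents.specs import tensor_spec


def observation_and_action_constraint_splitter(observation):
    """Splits a dict observation into (agent/policy input, action mask)."""
    return observation['context'], observation['action_mask']


num_actions = 4
observation_spec = {
    'context': tf.TensorSpec(shape=[8], dtype=tf.float32),
    # 0/1 validity mask over the actions (the "boolean mask" above).
    'action_mask': tensor_spec.BoundedTensorSpec(
        shape=[num_actions], dtype=tf.int32, minimum=0, maximum=1),
}
# Works on specs exactly as it would on concrete observations.
context_spec, mask_spec = observation_and_action_constraint_splitter(
    observation_spec)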
Example #4
    def __init__(
            self,
            time_step_spec: types.TimeStep,
            action_spec: types.BoundedTensorSpec,
            encoding_network: types.Network,
            encoding_network_num_train_steps: int,
            encoding_dim: int,
            optimizer: types.Optimizer,
            variable_collection: Optional[
                NeuralLinUCBVariableCollection] = None,
            alpha: float = 1.0,
            gamma: float = 1.0,
            epsilon_greedy: float = 0.0,
            observation_and_action_constraint_splitter: Optional[
                types.Splitter] = None,
            accepts_per_arm_features: bool = False,
            distributed_train_encoding_network: bool = False,
            # Params for training.
            error_loss_fn: types.LossFn = tf.compat.v1.losses.mean_squared_error,
            gradient_clipping: Optional[float] = None,
            # Params for debugging.
            debug_summaries: bool = False,
            summarize_grads_and_vars: bool = False,
            train_step_counter: Optional[tf.Variable] = None,
            emit_policy_info: Sequence[Text] = (),
            emit_log_probability: bool = False,
            dtype: tf.DType = tf.float64,
            name: Optional[Text] = None):
        """Initialize an instance of `NeuralLinUCBAgent`.

    Args:
      time_step_spec: A `TimeStep` spec describing the expected `TimeStep`s.
      action_spec: A scalar `BoundedTensorSpec` with `int32` or `int64` dtype
        describing the number of actions for this agent.
      encoding_network: a Keras network that encodes the observations.
      encoding_network_num_train_steps: how many training steps to run for
        training the encoding network before switching to LinUCB. If negative,
        the encoding network is assumed to be already trained.
      encoding_dim: the dimension of encoded observations.
      optimizer: The optimizer to use for training.
      variable_collection: Instance of `NeuralLinUCBVariableCollection`.
        Collection of variables to be updated by the agent. If `None`, a new
        instance of `NeuralLinUCBVariableCollection` will be created. Note
        that this collection excludes the variables owned by the encoding
        network.
      alpha: (float) positive scalar. This is the exploration parameter that
        multiplies the confidence intervals.
      gamma: a float forgetting factor in [0.0, 1.0]. When set to
        1.0, the algorithm does not forget.
      epsilon_greedy: A float representing the probability of choosing a random
        action instead of the greedy action.
      observation_and_action_constraint_splitter: A function used for masking
        valid/invalid actions with each state of the environment. The function
        takes in a full observation and returns a tuple consisting of 1) the
        part of the observation intended as input to the bandit agent and
        policy, and 2) the boolean mask. This function should also work with a
        `TensorSpec` as input, and should output `TensorSpec` objects for the
        observation and mask.
      accepts_per_arm_features: (bool) Whether the policy accepts per-arm
        features.
      distributed_train_encoding_network: (bool) whether to train the encoding
        network or not. This applies only in a distributed training setting.
        When set to `True`, this agent will train the encoding network.
        Otherwise, it will assume the encoding network is already trained and
        will train LinUCB on top of it.
      error_loss_fn: A function for computing the error loss, taking parameters
        labels, predictions, and weights (any function from tf.losses would
        work). The default is `tf.losses.mean_squared_error`.
      gradient_clipping: A float representing the norm length to clip gradients
        (or None for no clipping).
      debug_summaries: A Python bool, default False. When True, debug summaries
        are gathered.
      summarize_grads_and_vars: A Python bool, default False. When True,
        gradients and network variable summaries are written during training.
      train_step_counter: An optional `tf.Variable` to increment every time the
        train op is run.  Defaults to the `global_step`.
      emit_policy_info: (tuple of strings) what side information we want to get
        as part of the policy info. Allowed values can be found in
        `policy_utilities.PolicyInfo`.
      emit_log_probability: Whether the NeuralLinUCBPolicy emits
        log-probabilities or not. Since the policy is deterministic, the
        probability is just 1.
      dtype: The type of the parameters stored and updated by the agent. Should
        be one of `tf.float32` and `tf.float64`. Defaults to `tf.float64`.
      name: a name for this instance of `NeuralLinUCBAgent`.

    Raises:
      TypeError if variable_collection is not an instance of
        `NeuralLinUCBVariableCollection`.
      ValueError if dtype is not one of `tf.float32` or `tf.float64`.
    """
        tf.Module.__init__(self, name=name)
        common.tf_agents_gauge.get_cell('TFABandit').set(True)
        self._num_actions = policy_utilities.get_num_actions_from_tensor_spec(
            action_spec)
        self._num_models = 1 if accepts_per_arm_features else self._num_actions
        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
        self._accepts_per_arm_features = accepts_per_arm_features
        self._alpha = alpha
        if variable_collection is None:
            variable_collection = NeuralLinUCBVariableCollection(
                self._num_models, encoding_dim, dtype)
        elif not isinstance(variable_collection,
                            NeuralLinUCBVariableCollection):
            raise TypeError('Parameter `variable_collection` should be '
                            'of type `NeuralLinUCBVariableCollection`.')
        self._variable_collection = variable_collection
        self._gamma = gamma
        if self._gamma < 0.0 or self._gamma > 1.0:
            raise ValueError(
                'Forgetting factor `gamma` must be in [0.0, 1.0].')
        self._dtype = dtype
        if dtype not in (tf.float32, tf.float64):
            raise ValueError(
                'Agent dtype should be either `tf.float32` or `tf.float64`.')
        self._epsilon_greedy = epsilon_greedy

        reward_layer = tf.keras.layers.Dense(
            self._num_models,
            kernel_initializer=tf.random_uniform_initializer(minval=-0.03,
                                                             maxval=0.03),
            use_bias=False,
            activation=None,
            name='reward_layer')

        encoding_network.create_variables()
        self._encoding_network = encoding_network
        reward_layer.build(input_shape=tf.TensorShape([None, encoding_dim]))
        self._reward_layer = reward_layer
        self._encoding_network_num_train_steps = encoding_network_num_train_steps
        self._encoding_dim = encoding_dim
        self._optimizer = optimizer
        self._error_loss_fn = error_loss_fn
        self._gradient_clipping = gradient_clipping
        if train_step_counter is None:
            train_step_counter = tf.compat.v1.train.get_or_create_global_step()
        self._distributed_train_encoding_network = (
            distributed_train_encoding_network)

        policy = neural_linucb_policy.NeuralLinUCBPolicy(
            encoding_network=self._encoding_network,
            encoding_dim=self._encoding_dim,
            reward_layer=self._reward_layer,
            epsilon_greedy=self._epsilon_greedy,
            actions_from_reward_layer=self.actions_from_reward_layer,
            cov_matrix=self.cov_matrix,
            data_vector=self.data_vector,
            num_samples=self.num_samples,
            time_step_spec=time_step_spec,
            alpha=alpha,
            emit_policy_info=emit_policy_info,
            emit_log_probability=emit_log_probability,
            accepts_per_arm_features=accepts_per_arm_features,
            distributed_use_reward_layer=distributed_train_encoding_network,
            observation_and_action_constraint_splitter=(
                observation_and_action_constraint_splitter))

        training_data_spec = None
        if accepts_per_arm_features:
            training_data_spec = bandit_spec_utils.drop_arm_observation(
                policy.trajectory_spec)
        super(NeuralLinUCBAgent,
              self).__init__(time_step_spec=time_step_spec,
                             action_spec=policy.action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=None,
                             training_data_spec=training_data_spec,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
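# --- Minimal usage sketch for the NeuralLinUCB agent above (not part of the
# original source). Import paths, layer widths, and the number of encoder
# training steps are illustrative assumptions. The final layer of the encoding
# network must output vectors of size `encoding_dim`, since the reward layer
# above is built with input shape [None, encoding_dim].
import tensorflow as tf
from tf_agents.bandits.agents import neural_linucb_agent
from tf_agents.networks import encoding_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tf.TensorSpec(shape=[8], dtype=tf.float32)
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=4)  # 5 discrete actions.

encoding_dim = 16
encoder = encoding_network.EncodingNetwork(
    input_tensor_spec=observation_spec,
    fc_layer_params=(32, encoding_dim))  # Last layer width == encoding_dim.

agent = neural_linucb_agent.NeuralLinUCBAgent(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    encoding_network=encoder,
    encoding_network_num_train_steps=1000,  # Train the encoder first,
    encoding_dim=encoding_dim,              # then run LinUCB on top of it.
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))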