Example #1
  def __init__(
      self,
      action_spec: specs.DiscreteArray,
      network: snt.Module,
      batch_size: int,
      discount: float,
      replay_capacity: int,
      min_replay_size: int,
      sgd_period: int,
      target_update_period: int,
      optimizer: snt.Optimizer,
      epsilon: float,
      seed: int = None,
  ):

    # Internalise hyperparameters.
    self._num_actions = action_spec.num_values
    self._discount = discount
    self._batch_size = batch_size
    self._sgd_period = sgd_period
    self._target_update_period = target_update_period
    self._epsilon = epsilon
    self._min_replay_size = min_replay_size

    # Seed the RNG.
    tf.random.set_seed(seed)
    self._rng = np.random.RandomState(seed)

    # Internalise the components (networks, optimizer, replay buffer).
    self._optimizer = optimizer
    self._replay = replay.Replay(capacity=replay_capacity)
    self._online_network = network
    self._target_network = copy.deepcopy(network)
    self._forward = tf.function(network)
    self._total_steps = tf.Variable(0)
Example #2
    def __init__(
        self,
        action_spec: specs.DiscreteArray,
        online_network: snt.Module,
        target_network: snt.Module,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: snt.Optimizer,
        epsilon: float,
        seed: int = None,
    ):
        # DQN configuration and hyperparameters.
        self._num_actions = action_spec.num_values
        self._discount = discount
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._optimizer = optimizer
        self._epsilon = epsilon
        self._total_steps = 0
        self._replay = replay.Replay(capacity=replay_capacity)
        self._min_replay_size = min_replay_size

        tf.random.set_seed(seed)
        self._rng = np.random.RandomState(seed)

        # Internalize the networks.
        self._online_network = online_network
        self._target_network = target_network
        self._forward = tf.function(online_network)
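
A hedged construction sketch for the constructor above; the class name `DQN`, the 3-action spec, and the MLP sizes are assumptions, not taken from the snippet:

import copy

import sonnet as snt
from dm_env import specs

action_spec = specs.DiscreteArray(num_values=3)
online_network = snt.nets.MLP([50, 50, action_spec.num_values])
target_network = copy.deepcopy(online_network)

agent = DQN(  # Assumed class name for the constructor shown above.
    action_spec=action_spec,
    online_network=online_network,
    target_network=target_network,
    batch_size=32,
    discount=0.99,
    replay_capacity=10000,
    min_replay_size=100,
    sgd_period=1,
    target_update_period=4,
    optimizer=snt.optimizers.Adam(learning_rate=1e-3),
    epsilon=0.05,
    seed=42,
)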
Example #3
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      replay_capacity: int,
      batch_size: int,
      hidden_sizes: Tuple[int, ...],
      learning_rate: float = 1e-3,
      terminal_tol: float = 1e-3,
  ):
    self._obs_spec = environment_spec.observations
    self._action_spec = environment_spec.actions
    # Hyperparameters.
    self._batch_size = batch_size
    self._terminal_tol = terminal_tol

    # Modelling
    self._replay = replay.Replay(replay_capacity)
    self._transition_model = MLPTransitionModel(environment_spec, hidden_sizes)
    self._optimizer = snt.optimizers.Adam(learning_rate)
    self._forward = tf.function(self._transition_model)
    tf2_utils.create_variables(
        self._transition_model, [self._obs_spec, self._action_spec])
    self._variables = self._transition_model.trainable_variables

    # Model state.
    self._needs_reset = True
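
A hedged instantiation sketch for the model component above; the class name `MLPModel` is an assumption, and the `EnvironmentSpec` is built by hand rather than derived from a real environment:

import numpy as np
from acme import specs as acme_specs  # Assumption: Acme's EnvironmentSpec, matching tf2_utils above.
from dm_env import specs

environment_spec = acme_specs.EnvironmentSpec(
    observations=specs.Array(shape=(10, 5), dtype=np.float32),
    actions=specs.DiscreteArray(num_values=3),
    rewards=specs.Array(shape=(), dtype=np.float32),
    discounts=specs.BoundedArray(shape=(), dtype=np.float32, minimum=0., maximum=1.),
)

model = MLPModel(  # Assumed class name for the constructor shown above.
    environment_spec=environment_spec,
    replay_capacity=10000,
    batch_size=16,
    hidden_sizes=(50, 50),
)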
Example #4
    def test_end_to_end(self):
        shapes = (10, 10, 3), ()
        capacity = 5

        def generate_sample():
            return [
                np.random.randint(0, 256, size=(10, 10, 3), dtype=np.uint8),
                np.random.uniform(size=())
            ]

        replay = replay_lib.Replay(capacity=capacity)

        # Does it crash if we sample when there's barely any data?
        sample = generate_sample()
        replay.add(sample)
        samples = replay.sample(size=2)
        for sample, shape in zip(samples, shapes):
            self.assertEqual(sample.shape, (2, ) + shape)

        # Fill to capacity.
        for _ in range(capacity - 1):
            replay.add(generate_sample())
            samples = replay.sample(size=3)
            for sample, shape in zip(samples, shapes):
                self.assertEqual(sample.shape, (3, ) + shape)

        replay.add(generate_sample())
        samples = replay.sample(size=capacity)
        for sample, shape in zip(samples, shapes):
            self.assertEqual(sample.shape, (capacity, ) + shape)
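
The test above pins down the replay interface it relies on: `Replay(capacity)`, `add(items)`, and `sample(size)` returning one stacked array per item, with sampling that still works when the buffer holds fewer items than requested. A minimal sketch of a buffer with that interface (an assumption about `replay_lib.Replay`, not its actual implementation):

import random
from typing import Any, List, Sequence

import numpy as np


class Replay:
    """Uniform random replay over fixed-size lists of items (sketch only)."""

    def __init__(self, capacity: int):
        self._capacity = capacity
        self._data: List[Sequence[Any]] = []

    def add(self, items: Sequence[Any]) -> None:
        # Drop the oldest entry once the buffer is full.
        if len(self._data) >= self._capacity:
            self._data.pop(0)
        self._data.append(items)

    def sample(self, size: int) -> List[np.ndarray]:
        # Sample with replacement so a nearly-empty buffer can serve any batch size.
        rows = [random.choice(self._data) for _ in range(size)]
        return [np.stack(column) for column in zip(*rows)]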
Example #5
    def __init__(
        self,
        obs_spec: dm_env.specs.Array,
        action_spec: dm_env.specs.DiscreteArray,
        online_network: tf.keras.Sequential,
        target_network: tf.keras.Sequential,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: tf.keras.optimizers.Optimizer,
        seed: int = None,
    ):
        """A simple DQN agent."""
        # tf.keras.backend.set_floatx('float32')

        # DQN configuration and hyperparameters.
        self._num_actions = action_spec.num_values
        self._discount = discount
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._optimizer = optimizer
        self._total_steps = 0
        self._total_episodes = 0
        self._replay = replay.Replay(capacity=replay_capacity)
        self._min_replay_size = min_replay_size
        tf.random.set_seed(seed)
        self._rng = np.random.RandomState(seed)
        self._epsilon_fn = lambda t: 10 / (10 + t)

        self._online_network = online_network
        self._target_network = target_network
Example #6
    def __init__(
        self,
        action_spec: specs.DiscreteArray,
        network: Network,
        parameters: NetworkParameters,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        learning_rate: float,
        epsilon: float,
        seed: int = None,
    ):

        # DQN configuration and hyperparameters.
        self._num_actions = action_spec.num_values
        self._discount = discount
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._epsilon = epsilon
        self._total_steps = 0
        self._replay = replay.Replay(capacity=replay_capacity)
        self._min_replay_size = min_replay_size

        self._rng = np.random.RandomState(seed)

        def loss(online_params, target_params, transitions):
            o_tm1, a_tm1, r_t, d_t, o_t = transitions
            q_tm1 = network(online_params, o_tm1)
            q_t = network(target_params, o_t)
            q_target = r_t + d_t * discount * jnp.max(q_t, axis=-1)
            q_a_tm1 = jax.vmap(lambda q, a: q[a])(q_tm1, a_tm1)
            td_error = q_a_tm1 - lax.stop_gradient(q_target)

            return jnp.mean(td_error**2)

        # Internalize the networks.
        self._network = network
        self._parameters = parameters
        self._target_parameters = parameters

        # This function computes dL/dTheta, the gradient of the loss w.r.t. the online parameters.
        self._grad = jax.jit(jax.grad(loss))
        self._forward = jax.jit(network)

        # Make an Adam optimizer.
        opt_init, opt_update, get_params = optimizers.adam(
            step_size=learning_rate)
        self._opt_update = jax.jit(opt_update)
        self._opt_state = opt_init(parameters)
        self._get_params = get_params
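
The constructor above only builds the pieces; a hedged sketch of how the jitted gradient, the stax-style optimizer triple, and the parameter getter might be wired into a single update (the method name `_training_step` and the transition layout are assumptions):

    # Hedged sketch; not part of the original class.
    def _training_step(self, transitions):
        """Applies one gradient step with the stax-style optimizer triple."""
        online_params = self._get_params(self._opt_state)
        gradient = self._grad(online_params, self._target_parameters, transitions)
        # The stax optimizer update signature is opt_update(step, grads, opt_state).
        self._opt_state = self._opt_update(self._total_steps, gradient, self._opt_state)
        self._parameters = self._get_params(self._opt_state)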
Example #7
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        ensemble: Sequence[snt.Module],
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: snt.Optimizer,
        mask_prob: float,
        noise_scale: float,
        epsilon_fn: Callable[[int], float] = lambda _: 0.,
        seed: int = None,
    ):
        """Bootstrapped DQN with additive prior functions."""
        # Agent components.
        self._ensemble = ensemble
        self._forward = [tf.function(net) for net in ensemble]
        self._target_ensemble = [
            copy.deepcopy(network) for network in ensemble
        ]
        self._num_ensemble = len(ensemble)
        self._optimizer = optimizer
        self._replay = replay.Replay(capacity=replay_capacity)

        # Create variables for each network in the ensemble
        for network in ensemble:
            snt.build(network, (None, *obs_spec.shape))

        # Agent hyperparameters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._epsilon_fn = epsilon_fn
        self._mask_prob = mask_prob
        self._noise_scale = noise_scale
        self._rng = np.random.RandomState(seed)
        self._discount = discount

        # Agent state.
        self._total_steps = tf.Variable(1)
        self._active_head = 0
        tf.random.set_seed(seed)
Example #8
    def __init__(
        self,
        obs_spec: dm_env.specs.Array,
        action_spec: dm_env.specs.BoundedArray,
        ensemble: Sequence[snt.AbstractModule],
        target_ensemble: Sequence[snt.AbstractModule],
        batch_size: int,
        agent_discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: tf.train.Optimizer,
        mask_prob: float,
        noise_scale: float,
        epsilon_fn: Callable[[int], float] = lambda _: 0.,
        seed: int = None,
    ):
        """Bootstrapped DQN with additive prior functions."""
        # DQN configuration and hyperparameters.
        self._ensemble = ensemble
        self._target_ensemble = target_ensemble
        self._num_actions = action_spec.maximum - action_spec.minimum + 1
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._epsilon_fn = epsilon_fn
        self._replay = replay.Replay(capacity=replay_capacity)
        self._mask_prob = mask_prob
        self._noise_scale = noise_scale
        self._rng = np.random.RandomState(seed)
        tf.set_random_seed(seed)

        self._total_steps = 0
        self._total_episodes = 0
        self._active_head = 0
        self._num_ensemble = len(ensemble)
        assert len(ensemble) == len(target_ensemble)

        # Make the TensorFlow graph.
        session = tf.Session()

        # Placeholders = (obs, action, reward, discount, next_obs, mask, noise)
        o_tm1 = tf.placeholder(shape=(None, ) + obs_spec.shape,
                               dtype=obs_spec.dtype)
        a_tm1 = tf.placeholder(shape=(None, ), dtype=action_spec.dtype)
        r_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        d_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        o_t = tf.placeholder(shape=(None, ) + obs_spec.shape,
                             dtype=obs_spec.dtype)
        m_t = tf.placeholder(shape=(None, self._num_ensemble),
                             dtype=tf.float32)
        z_t = tf.placeholder(shape=(None, self._num_ensemble),
                             dtype=tf.float32)

        losses = []
        value_fns = []
        target_updates = []
        for k in range(self._num_ensemble):
            model = self._ensemble[k]
            target_model = self._target_ensemble[k]
            q_values = model(o_tm1)

            train_value = batched_index(q_values, a_tm1)
            target_value = tf.stop_gradient(
                tf.reduce_max(target_model(o_t), axis=-1))
            target_y = r_t + z_t[:, k] + agent_discount * d_t * target_value
            loss = tf.square(train_value - target_y) * m_t[:, k]

            value_fn = session.make_callable(q_values, [o_tm1])
            target_update = update_target_variables(
                target_variables=target_model.get_all_variables(),
                source_variables=model.get_all_variables(),
            )

            losses.append(loss)
            value_fns.append(value_fn)
            target_updates.append(target_update)

        sgd_op = optimizer.minimize(tf.stack(losses))
        self._value_fns = value_fns
        self._sgd_step = session.make_callable(
            sgd_op, [o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t])
        self._update_target_nets = session.make_callable(target_updates)
        session.run(tf.global_variables_initializer())
Example #9
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        online_network: snt.AbstractModule,
        target_network: snt.AbstractModule,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: tf.train.Optimizer,
        epsilon: float,
        seed: int = None,
    ):
        """A simple DQN agent."""

        # DQN configuration and hyperparameters.
        self._num_actions = action_spec.num_values
        self._discount = discount
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._optimizer = optimizer
        self._epsilon = epsilon
        self._total_steps = 0
        self._replay = replay.Replay(capacity=replay_capacity)
        self._min_replay_size = min_replay_size
        tf.set_random_seed(seed)
        self._rng = np.random.RandomState(seed)

        # Make the TensorFlow graph.
        o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
        q = online_network(tf.expand_dims(o, 0))

        o_tm1 = tf.placeholder(shape=(None, ) + obs_spec.shape,
                               dtype=obs_spec.dtype)
        a_tm1 = tf.placeholder(shape=(None, ), dtype=action_spec.dtype)
        r_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        d_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        o_t = tf.placeholder(shape=(None, ) + obs_spec.shape,
                             dtype=obs_spec.dtype)

        q_tm1 = online_network(o_tm1)
        q_t = target_network(o_t)
        loss = qlearning(q_tm1, a_tm1, r_t, discount * d_t, q_t).loss

        train_op = self._optimizer.minimize(loss)
        with tf.control_dependencies([train_op]):
            train_op = periodic_target_update(
                target_variables=target_network.variables,
                source_variables=online_network.variables,
                update_period=target_update_period)

        # Make session and callables.
        session = tf.Session()
        self._sgd_fn = session.make_callable(train_op,
                                             [o_tm1, a_tm1, r_t, d_t, o_t])
        self._value_fn = session.make_callable(q, [o])
        session.run(tf.global_variables_initializer())
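
A hedged sketch of how the two session callables built above are typically used; the method names, the epsilon-greedy policy, and the five-component transition layout are assumptions about the rest of the class:

    # Hedged sketch; not part of the original class.
    def select_action(self, observation: np.ndarray) -> int:
        """Epsilon-greedy action selection over the online Q-values."""
        if self._rng.rand() < self._epsilon:
            return self._rng.randint(self._num_actions)
        q_values = self._value_fn(observation)  # Shape: [1, num_actions].
        return int(np.argmax(q_values))

    def _training_step(self) -> None:
        """Runs one SGD step (the op also performs the periodic target update)."""
        o_tm1, a_tm1, r_t, d_t, o_t = self._replay.sample(self._batch_size)
        self._sgd_fn(o_tm1, a_tm1, r_t, d_t, o_t)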
Example #10
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: hk.Transformed,
        num_ensemble: int,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: optix.InitUpdate,
        mask_prob: float,
        noise_scale: float,
        epsilon_fn: Callable[[int], float] = lambda _: 0.,
        seed: int = 1,
    ):
        """Bootstrapped DQN with randomized prior functions."""

        # Define loss function, including bootstrap mask `m_t` & reward noise `z_t`.
        def loss(params: hk.Params, target_params: hk.Params,
                 transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
            """Q-learning loss with added reward noise + half-in bootstrap."""
            o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
            q_tm1 = network.apply(params, o_tm1)
            q_t = network.apply(target_params, o_t)
            r_t += noise_scale * z_t
            batch_q_learning = jax.vmap(rlax.q_learning)
            td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
            return jnp.mean(m_t * td_error**2)

        # Define the update function for each member of the ensemble.
        @jax.jit
        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Does a step of SGD for the whole ensemble over `transitions`."""

            gradients = jax.grad(loss)(state.params, state.target_params,
                                       transitions)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            return TrainingState(params=new_params,
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)

        # Initialize parameters and optimizer state for an ensemble of Q-networks.
        rng = hk.PRNGSequence(seed)
        dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32)
        initial_params = [
            network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
        ]
        initial_target_params = [
            network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
        ]
        initial_opt_state = [optimizer.init(p) for p in initial_params]

        # Internalize state.
        self._ensemble = [
            TrainingState(p, tp, o, step=0) for p, tp, o in zip(
                initial_params, initial_target_params, initial_opt_state)
        ]
        self._forward = jax.jit(network.apply)
        self._sgd_step = sgd_step
        self._num_ensemble = num_ensemble
        self._optimizer = optimizer
        self._replay = replay.Replay(capacity=replay_capacity)

        # Agent hyperparameters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._min_replay_size = min_replay_size
        self._epsilon_fn = epsilon_fn
        self._mask_prob = mask_prob

        # Agent state.
        self._active_head = self._ensemble[0]
        self._total_steps = 0
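
A hedged sketch of how the per-member `sgd_step` above might be applied across the ensemble; the method name and the assumption that `m_t` and `z_t` carry one column per ensemble member are not taken from the snippet:

    # Hedged sketch; not part of the original class.
    def _update_ensemble(self, transitions: Sequence[jnp.ndarray]) -> None:
        """Runs one SGD step per ensemble member with its own mask/noise column."""
        o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
        for k, state in enumerate(self._ensemble):
            member_transitions = [o_tm1, a_tm1, r_t, d_t, o_t, m_t[:, k], z_t[:, k]]
            self._ensemble[k] = self._sgd_step(state, member_transitions)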
Example #11
    def __init__(
        self,
        obs_spec: specs.Array,
        action_spec: specs.DiscreteArray,
        network: hk.Transformed,
        optimizer: optix.InitUpdate,
        batch_size: int,
        epsilon: float,
        rng: hk.PRNGSequence,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
    ):

        # Define loss function.
        def loss(params: hk.Params, target_params: hk.Params,
                 transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
            """Computes the standard TD(0) Q-learning loss on batch of transitions."""
            o_tm1, a_tm1, r_t, d_t, o_t = transitions
            q_tm1 = network.apply(params, o_tm1)
            q_t = network.apply(target_params, o_t)
            batch_q_learning = jax.vmap(rlax.q_learning)
            td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
            return jnp.mean(td_error**2)

        # Define update function.
        @jax.jit
        def sgd_step(state: TrainingState,
                     transitions: Sequence[jnp.ndarray]) -> TrainingState:
            """Performs an SGD step on a batch of transitions."""
            gradients = jax.grad(loss)(state.params, state.target_params,
                                       transitions)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            return TrainingState(params=new_params,
                                 target_params=state.target_params,
                                 opt_state=new_opt_state,
                                 step=state.step + 1)

        # Initialize the networks and optimizer.
        dummy_observation = np.zeros((1, *obs_spec.shape), jnp.float32)
        initial_params = network.init(next(rng), dummy_observation)
        initial_target_params = network.init(next(rng), dummy_observation)
        initial_opt_state = optimizer.init(initial_params)

        # This carries the agent state relevant to training.
        self._state = TrainingState(params=initial_params,
                                    target_params=initial_target_params,
                                    opt_state=initial_opt_state,
                                    step=0)
        self._sgd_step = sgd_step
        self._forward = jax.jit(network.apply)
        self._replay = replay.Replay(capacity=replay_capacity)

        # Store hyperparameters.
        self._num_actions = action_spec.num_values
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._epsilon = epsilon
        self._total_steps = 0
        self._min_replay_size = min_replay_size
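
A hedged sketch of how the training state above might be advanced, with a periodic copy of the online parameters into the target; the method name `update` is an assumption:

    # Hedged sketch; not part of the original class.
    def update(self, transitions: Sequence[jnp.ndarray]) -> None:
        """Runs one SGD step, then periodically syncs the target parameters."""
        self._state = self._sgd_step(self._state, transitions)
        if self._state.step % self._target_update_period == 0:
            self._state = TrainingState(
                params=self._state.params,
                target_params=self._state.params,  # Copy online params into the target.
                opt_state=self._state.opt_state,
                step=self._state.step)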
Example #12
    def __init__(
        self,
        obs_spec: dm_env.specs.Array,
        action_spec: dm_env.specs.DiscreteArray,
        online_network: tf.keras.Sequential,
        target_network: tf.keras.Sequential,
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer: tf.keras.optimizers.Optimizer,
        posterior_optimizer: tf.optimizers.Optimizer = tf.optimizers.SGD(
            learning_rate=1e-2),
        seed: int = None,
    ):
        """A simple DQN agent."""
        # DQN configuration and hyperparameters.
        self._num_features = 32
        self._num_actions = action_spec.num_values
        self._discount = discount
        self._batch_size = batch_size
        self._optimizer = optimizer
        self._posterior_optimizer = posterior_optimizer
        self._total_steps = 0
        self._total_episodes = 0
        self._replay = replay.Replay(capacity=replay_capacity)
        self._min_replay_size = min_replay_size

        # Update periods (the paper's values are noted alongside).
        self._sgd_period = sgd_period  # The paper uses 4.
        self._target_update_period = target_update_period  # The paper uses 10000 (10k).
        self._posterior_update_period = 500  # The paper uses 100000 (100k).
        self._sample_out_mus_period = 20  # The paper uses 1000 (1k).

        # Neural network for the features.
        self._online_network = online_network
        self._target_network = target_network

        # Normal output distributions, one per action.
        self._target_mus = []
        self._target_mu_covs = []
        self._normal_distros = []
        self._out_mus = []
        eye = tf.eye(self._num_features)
        bijector = tfp.bijectors.FillTriangular(upper=False)
        for idx in range(self._num_actions):
            mu = tf.random.normal([self._num_features])  # Shape: [num_features].
            cov = tf.random.normal(
                [self._num_features, self._num_features],
                stddev=0.1) + eye  # Shape: [num_features, num_features].
            cov = tf.linalg.band_part(cov, 0, -1)  # Keep the upper triangle.
            cov = 0.5 * (cov + tf.transpose(cov))  # Make it symmetric.
            chol = tf.linalg.cholesky(cov)
            chol = bijector.inverse(chol)
            mu_cov = tf.concat([mu, chol], 0)
            mu_cov = tf.Variable(mu_cov)
            normal_distro = tfp.layers.MultivariateNormalTriL(
                self._num_features)
            self._target_mus.append(normal_distro(mu_cov).mean())
            self._target_mu_covs.append(mu_cov)
            self._normal_distros.append(normal_distro)
            self._out_mus.append(normal_distro(mu_cov).sample())

        # Use a consistent Keras backend float type.
        tf.keras.backend.set_floatx('float32')
Example #13
    def __init__(
        self,
        obs_spec: dm_env.specs.Array,
        action_spec: dm_env.specs.BoundedArray,
        q_network: snt.AbstractModule,
        target_q_network: snt.AbstractModule,
        rho_network: snt.AbstractModule,
        l_network: Sequence[snt.AbstractModule],
        target_l_network: Sequence[snt.AbstractModule],
        batch_size: int,
        discount: float,
        replay_capacity: int,
        min_replay_size: int,
        sgd_period: int,
        target_update_period: int,
        optimizer_primal: tf.train.Optimizer,
        optimizer_dual: tf.train.Optimizer,
        optimizer_l: tf.train.Optimizer,
        learn_iters: int,
        l_approximators: int,
        min_l: float,
        kappa: float,
        eta1: float,
        eta2: float,
        seed: int = None,
    ):
        """Information seeking learner."""
        # ISL configurations.
        self.q_network = q_network
        self._target_q_network = target_q_network
        self.rho_network = rho_network
        self.l_network = l_network
        self._target_l_network = target_l_network
        self._num_actions = action_spec.maximum - action_spec.minimum + 1
        self._obs_shape = obs_spec.shape
        self._batch_size = batch_size
        self._sgd_period = sgd_period
        self._target_update_period = target_update_period
        self._optimizer_primal = optimizer_primal
        self._optimizer_dual = optimizer_dual
        self._optimizer_l = optimizer_l
        self._min_replay_size = min_replay_size
        # Alternative replay: ISLReplay(capacity=replay_capacity, average_l=0, mu=0).
        self._replay = replay.Replay(capacity=replay_capacity)
        self._rng = np.random.RandomState(seed)
        tf.set_random_seed(seed)
        self._kappa = kappa
        self._min_l = min_l
        self._eta1 = eta1
        self._eta2 = eta2
        self._learn_iters = learn_iters
        self._l_approximators = l_approximators
        self._total_steps = 0
        self._total_episodes = 0
        self._learn_iter_counter = 0

        # Make the TensorFlow graph.
        o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
        q = q_network(tf.expand_dims(o, 0))
        rho = rho_network(tf.expand_dims(o, 0))
        l = []
        for k in range(self._l_approximators):
            l.append(
                tf.concat([
                    l_network[k][a](tf.expand_dims(o, 0))
                    for a in range(self._num_actions)
                ],
                          axis=1))

        # Placeholders = (obs, action, reward, discount, next_obs)
        o_tm1 = tf.placeholder(shape=(None, ) + obs_spec.shape,
                               dtype=obs_spec.dtype)
        a_tm1 = tf.placeholder(shape=(None, ), dtype=action_spec.dtype)
        r_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        d_t = tf.placeholder(shape=(None, ), dtype=tf.float32)
        o_t = tf.placeholder(shape=(None, ) + obs_spec.shape,
                             dtype=obs_spec.dtype)
        chosen_l = tf.placeholder(shape=1,
                                  dtype=tf.int32,
                                  name='chosen_l_tensor')

        q_tm1 = q_network(o_tm1)
        rho_tm1 = rho_network(o_tm1)
        train_q_value = batched_index(q_tm1, a_tm1)
        train_rho_value = batched_index(rho_tm1, a_tm1)
        train_rho_value_no_grad = tf.stop_gradient(train_rho_value)
        if self._target_update_period > 1:
            q_t = target_q_network(o_t)
        else:
            q_t = q_network(o_t)

        l_tm1_all = tf.stack([
            tf.concat([
                self.l_network[k][a](o_tm1) for a in range(self._num_actions)
            ],
                      axis=1) for k in range(self._l_approximators)
        ],
                             axis=-1)
        l_tm1 = tf.squeeze(tf.gather(l_tm1_all, chosen_l, axis=-1), axis=-1)
        train_l_value = batched_index(l_tm1, a_tm1)

        if self._target_update_period > 1:
            l_online_t_all = tf.stack([
                tf.concat([
                    self.l_network[k][a](o_t) for a in range(self._num_actions)
                ],
                          axis=1) for k in range(self._l_approximators)
            ],
                                      axis=-1)
            l_online_t = tf.squeeze(tf.gather(l_online_t_all,
                                              chosen_l,
                                              axis=-1),
                                    axis=-1)
            l_t_all = tf.stack([
                tf.concat([
                    self._target_l_network[k][a](o_t)
                    for a in range(self._num_actions)
                ],
                          axis=1) for k in range(self._l_approximators)
            ],
                               axis=-1)
            l_t = tf.squeeze(tf.gather(l_t_all, chosen_l, axis=-1), axis=-1)
            max_ind = tf.math.argmax(l_online_t, axis=1)
        else:
            l_t_all = tf.stack([
                tf.concat([
                    self.l_network[k][a](o_t) for a in range(self._num_actions)
                ],
                          axis=1) for k in range(self._l_approximators)
            ],
                               axis=-1)
            l_t = tf.squeeze(tf.gather(l_t_all, chosen_l, axis=-1), axis=-1)
            max_ind = tf.math.argmax(l_t, axis=1)

        soft_max_value = tf.stop_gradient(
            tf.py_function(func=self.soft_max, inp=[q_t, l_t],
                           Tout=tf.float32))
        q_target_value = r_t + discount * d_t * soft_max_value
        delta_primal = train_q_value - q_target_value
        loss_primal = tf.add(eta2 * train_rho_value_no_grad * delta_primal,
                             (1 - eta2) * 0.5 * tf.square(delta_primal),
                             name='loss_q')

        delta_dual = tf.stop_gradient(delta_primal)
        loss_dual = tf.square(delta_dual - train_rho_value, name='loss_rho')

        l_greedy_estimate = tf.add((1 - eta1) * tf.math.abs(delta_primal),
                                   eta1 * tf.math.abs(train_rho_value_no_grad),
                                   name='l_greedy_estimate')
        l_target_value = tf.stop_gradient(
            l_greedy_estimate + discount * d_t * batched_index(l_t, max_ind),
            name='l_target')
        loss_l = 0.5 * tf.square(train_l_value - l_target_value)

        train_op_primal = self._optimizer_primal.minimize(loss_primal)
        train_op_dual = self._optimizer_dual.minimize(loss_dual)
        train_op_l = self._optimizer_l.minimize(loss_l)

        # Create target update operations.
        if self._target_update_period > 1:
            target_updates = []
            target_update = update_target_variables(
                target_variables=self._target_q_network.get_all_variables(),
                source_variables=self.q_network.get_all_variables(),
            )
            target_updates.append(target_update)
            for k in range(self._l_approximators):
                for a in range(self._num_actions):
                    model = self.l_network[k][a]
                    target_model = self._target_l_network[k][a]
                    target_update = update_target_variables(
                        target_variables=target_model.get_all_variables(),
                        source_variables=model.get_all_variables(),
                    )
                    target_updates.append(target_update)

        # Make session and callables.
        session = tf.Session()
        self._sgd = session.make_callable(
            [train_op_l, train_op_primal, train_op_dual],
            [o_tm1, a_tm1, r_t, d_t, o_t, chosen_l])
        self._q_fn = session.make_callable(q, [o])
        self._rho_fn = session.make_callable(rho, [o])
        self._l_fn = []
        for k in range(self._l_approximators):
            self._l_fn.append(session.make_callable(l[k], [o]))
        if self._target_update_period > 1:
            self._update_target_nets = session.make_callable(target_updates)
        session.run(tf.global_variables_initializer())