def __init__(
    self,
    action_spec: specs.DiscreteArray,
    network: snt.Module,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: snt.Optimizer,
    epsilon: float,
    seed: int = None,
):
  # Internalise hyperparameters.
  self._num_actions = action_spec.num_values
  self._discount = discount
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._epsilon = epsilon
  self._min_replay_size = min_replay_size

  # Seed the RNG.
  tf.random.set_seed(seed)
  self._rng = np.random.RandomState(seed)

  # Internalise the components (networks, optimizer, replay buffer).
  self._optimizer = optimizer
  self._replay = replay.Replay(capacity=replay_capacity)
  self._online_network = network
  self._target_network = copy.deepcopy(network)
  self._forward = tf.function(network)
  self._total_steps = tf.Variable(0)
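# Hypothetical construction sketch for the Sonnet/TF2 agent above. The class
# name `DQN`, the layer sizes, and the hyperparameter values are illustrative
# assumptions, not part of the source.
import sonnet as snt
from dm_env import specs

action_spec = specs.DiscreteArray(num_values=3)
network = snt.Sequential([
    snt.Flatten(),
    snt.nets.MLP([50, 50, action_spec.num_values]),
])
agent = DQN(
    action_spec=action_spec,
    network=network,
    batch_size=32,
    discount=0.99,
    replay_capacity=10000,
    min_replay_size=100,
    sgd_period=1,
    target_update_period=4,
    optimizer=snt.optimizers.Adam(learning_rate=1e-3),
    epsilon=0.05,
    seed=42,
)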
def __init__(
    self,
    action_spec: specs.DiscreteArray,
    online_network: snt.Module,
    target_network: snt.Module,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: snt.Optimizer,
    epsilon: float,
    seed: int = None,
):
  # DQN configuration and hyperparameters.
  self._num_actions = action_spec.num_values
  self._discount = discount
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._optimizer = optimizer
  self._epsilon = epsilon
  self._total_steps = 0
  self._replay = replay.Replay(capacity=replay_capacity)
  self._min_replay_size = min_replay_size

  tf.random.set_seed(seed)
  self._rng = np.random.RandomState(seed)

  # Internalize the networks.
  self._online_network = online_network
  self._target_network = target_network
  self._forward = tf.function(online_network)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    replay_capacity: int,
    batch_size: int,
    hidden_sizes: Tuple[int, ...],
    learning_rate: float = 1e-3,
    terminal_tol: float = 1e-3,
):
  self._obs_spec = environment_spec.observations
  self._action_spec = environment_spec.actions

  # Hyperparameters.
  self._batch_size = batch_size
  self._terminal_tol = terminal_tol

  # Modelling.
  self._replay = replay.Replay(replay_capacity)
  self._transition_model = MLPTransitionModel(environment_spec, hidden_sizes)
  self._optimizer = snt.optimizers.Adam(learning_rate)
  self._forward = tf.function(self._transition_model)
  tf2_utils.create_variables(
      self._transition_model, [self._obs_spec, self._action_spec])
  self._variables = self._transition_model.trainable_variables

  # Model state.
  self._needs_reset = True
def test_end_to_end(self):
  shapes = (10, 10, 3), ()
  capacity = 5

  def generate_sample():
    return [
        np.random.randint(0, 256, size=(10, 10, 3), dtype=np.uint8),
        np.random.uniform(size=())
    ]

  replay = replay_lib.Replay(capacity=capacity)

  # Does it crash if we sample when there's barely any data?
  sample = generate_sample()
  replay.add(sample)
  samples = replay.sample(size=2)
  for sample, shape in zip(samples, shapes):
    self.assertEqual(sample.shape, (2,) + shape)

  # Fill to capacity.
  for _ in range(capacity - 1):
    replay.add(generate_sample())
  samples = replay.sample(size=3)
  for sample, shape in zip(samples, shapes):
    self.assertEqual(sample.shape, (3,) + shape)

  replay.add(generate_sample())
  samples = replay.sample(size=capacity)
  for sample, shape in zip(samples, shapes):
    self.assertEqual(sample.shape, (capacity,) + shape)
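# A minimal uniform-sampling buffer consistent with the behaviour the test
# above exercises. This sketch is an assumption about the interface
# (`add`/`sample`), not the actual `replay_lib.Replay` implementation.
import collections

import numpy as np


class Replay:

  def __init__(self, capacity: int):
    self._buffer = collections.deque(maxlen=capacity)

  def add(self, items):
    # `items` is a list of per-field arrays describing one transition.
    self._buffer.append(items)

  def sample(self, size: int):
    # Sample `size` transitions with replacement and stack each field,
    # returning arrays of shape (size,) + field_shape.
    indices = np.random.randint(len(self._buffer), size=size)
    columns = zip(*[self._buffer[i] for i in indices])
    return [np.stack(column) for column in columns]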
def __init__(
    self,
    obs_spec: dm_env.specs.Array,
    action_spec: dm_env.specs.DiscreteArray,
    online_network: tf.keras.Sequential,
    target_network: tf.keras.Sequential,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: tf.keras.optimizers.Optimizer,
    seed: int = None,
):
  """A simple DQN agent."""
  # tf.keras.backend.set_floatx('float32')

  # DQN configuration and hyperparameters.
  self._num_actions = action_spec.num_values
  self._discount = discount
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._optimizer = optimizer
  self._total_steps = 0
  self._total_episodes = 0
  self._replay = replay.Replay(capacity=replay_capacity)
  self._min_replay_size = min_replay_size

  tf.random.set_seed(seed)
  self._rng = np.random.RandomState(seed)
  self._epsilon_fn = lambda t: 10 / (10 + t)

  self._online_network = online_network
  self._target_network = target_network
def __init__(
    self,
    action_spec: specs.DiscreteArray,
    network: Network,
    parameters: NetworkParameters,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    learning_rate: float,
    epsilon: float,
    seed: int = None,
):
  # DQN configuration and hyperparameters.
  self._num_actions = action_spec.num_values
  self._discount = discount
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._epsilon = epsilon
  self._total_steps = 0
  self._replay = replay.Replay(capacity=replay_capacity)
  self._min_replay_size = min_replay_size
  self._rng = np.random.RandomState(seed)

  def loss(online_params, target_params, transitions):
    o_tm1, a_tm1, r_t, d_t, o_t = transitions
    q_tm1 = network(online_params, o_tm1)
    q_t = network(target_params, o_t)
    q_target = r_t + d_t * discount * jnp.max(q_t, axis=-1)
    q_a_tm1 = jax.vmap(lambda q, a: q[a])(q_tm1, a_tm1)
    td_error = q_a_tm1 - lax.stop_gradient(q_target)
    return jnp.mean(td_error**2)

  # Internalize the networks.
  self._network = network
  self._parameters = parameters
  self._target_parameters = parameters

  # This function computes dL/dTheta.
  self._grad = jax.jit(jax.grad(loss))
  self._forward = jax.jit(network)

  # Make an Adam optimizer.
  opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)
  self._opt_update = jax.jit(opt_update)
  self._opt_state = opt_init(parameters)
  self._get_params = get_params
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    ensemble: Sequence[snt.Module],
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: snt.Optimizer,
    mask_prob: float,
    noise_scale: float,
    epsilon_fn: Callable[[int], float] = lambda _: 0.,
    seed: int = None,
):
  """Bootstrapped DQN with additive prior functions."""
  # Agent components.
  self._ensemble = ensemble
  self._forward = [tf.function(net) for net in ensemble]
  self._target_ensemble = [copy.deepcopy(network) for network in ensemble]
  self._num_ensemble = len(ensemble)
  self._optimizer = optimizer
  self._replay = replay.Replay(capacity=replay_capacity)

  # Create variables for each network in the ensemble.
  for network in ensemble:
    snt.build(network, (None, *obs_spec.shape))

  # Agent hyperparameters.
  self._num_actions = action_spec.num_values
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._min_replay_size = min_replay_size
  self._epsilon_fn = epsilon_fn
  self._mask_prob = mask_prob
  self._noise_scale = noise_scale
  self._rng = np.random.RandomState(seed)
  self._discount = discount

  # Agent state.
  self._total_steps = tf.Variable(1)
  self._active_head = 0
  tf.random.set_seed(seed)
def __init__(
    self,
    obs_spec: dm_env.specs.Array,
    action_spec: dm_env.specs.BoundedArray,
    ensemble: Sequence[snt.AbstractModule],
    target_ensemble: Sequence[snt.AbstractModule],
    batch_size: int,
    agent_discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: tf.train.Optimizer,
    mask_prob: float,
    noise_scale: float,
    epsilon_fn: Callable[[int], float] = lambda _: 0.,
    seed: int = None,
):
  """Bootstrapped DQN with additive prior functions."""
  # DQN configuration.
  self._ensemble = ensemble
  self._target_ensemble = target_ensemble
  self._num_actions = action_spec.maximum - action_spec.minimum + 1
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._min_replay_size = min_replay_size
  self._epsilon_fn = epsilon_fn
  self._replay = replay.Replay(capacity=replay_capacity)
  self._mask_prob = mask_prob
  self._noise_scale = noise_scale
  self._rng = np.random.RandomState(seed)
  tf.set_random_seed(seed)
  self._total_steps = 0
  self._total_episodes = 0
  self._active_head = 0
  self._num_ensemble = len(ensemble)
  assert len(ensemble) == len(target_ensemble)

  # Making the TensorFlow graph.
  session = tf.Session()

  # Placeholders = (obs, action, reward, discount, next_obs, mask, noise).
  o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
  a_tm1 = tf.placeholder(shape=(None,), dtype=action_spec.dtype)
  r_t = tf.placeholder(shape=(None,), dtype=tf.float32)
  d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
  o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
  m_t = tf.placeholder(shape=(None, self._num_ensemble), dtype=tf.float32)
  z_t = tf.placeholder(shape=(None, self._num_ensemble), dtype=tf.float32)

  losses = []
  value_fns = []
  target_updates = []
  for k in range(self._num_ensemble):
    model = self._ensemble[k]
    target_model = self._target_ensemble[k]
    q_values = model(o_tm1)
    train_value = batched_index(q_values, a_tm1)
    target_value = tf.stop_gradient(
        tf.reduce_max(target_model(o_t), axis=-1))
    target_y = r_t + z_t[:, k] + agent_discount * d_t * target_value
    loss = tf.square(train_value - target_y) * m_t[:, k]

    value_fn = session.make_callable(q_values, [o_tm1])
    target_update = update_target_variables(
        target_variables=target_model.get_all_variables(),
        source_variables=model.get_all_variables(),
    )

    losses.append(loss)
    value_fns.append(value_fn)
    target_updates.append(target_update)

  sgd_op = optimizer.minimize(tf.stack(losses))
  self._value_fns = value_fns
  self._sgd_step = session.make_callable(
      sgd_op, [o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t])
  self._update_target_nets = session.make_callable(target_updates)
  session.run(tf.global_variables_initializer())
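# Sketch of the `batched_index(values, indices)` helper that the TF1 graphs
# above and below rely on (trfl ships an equivalent). This local definition is
# an assumption, shown only to document the expected shapes.
import tensorflow as tf


def batched_index(values, indices):
  """Selects values[i, indices[i]] for each row i of a [B, num_actions] tensor."""
  one_hot = tf.one_hot(indices, tf.shape(values)[-1], dtype=values.dtype)
  return tf.reduce_sum(values * one_hot, axis=-1)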
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    online_network: snt.AbstractModule,
    target_network: snt.AbstractModule,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: tf.train.Optimizer,
    epsilon: float,
    seed: int = None,
):
  """A simple DQN agent."""
  # DQN configuration and hyperparameters.
  self._num_actions = action_spec.num_values
  self._discount = discount
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._optimizer = optimizer
  self._epsilon = epsilon
  self._total_steps = 0
  self._replay = replay.Replay(capacity=replay_capacity)
  self._min_replay_size = min_replay_size
  tf.set_random_seed(seed)
  self._rng = np.random.RandomState(seed)

  # Make the TensorFlow graph.
  o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
  q = online_network(tf.expand_dims(o, 0))

  o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
  a_tm1 = tf.placeholder(shape=(None,), dtype=action_spec.dtype)
  r_t = tf.placeholder(shape=(None,), dtype=tf.float32)
  d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
  o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)

  q_tm1 = online_network(o_tm1)
  q_t = target_network(o_t)
  loss = qlearning(q_tm1, a_tm1, r_t, discount * d_t, q_t).loss

  train_op = self._optimizer.minimize(loss)
  with tf.control_dependencies([train_op]):
    train_op = periodic_target_update(
        target_variables=target_network.variables,
        source_variables=online_network.variables,
        update_period=target_update_period)

  # Make session and callables.
  session = tf.Session()
  self._sgd_fn = session.make_callable(
      train_op, [o_tm1, a_tm1, r_t, d_t, o_t])
  self._value_fn = session.make_callable(q, [o])
  session.run(tf.global_variables_initializer())
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: hk.Transformed,
    num_ensemble: int,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: optix.InitUpdate,
    mask_prob: float,
    noise_scale: float,
    epsilon_fn: Callable[[int], float] = lambda _: 0.,
    seed: int = 1,
):
  """Bootstrapped DQN with randomized prior functions."""

  # Define loss function, including bootstrap mask `m_t` & reward noise `z_t`.
  def loss(params: hk.Params, target_params: hk.Params,
           transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
    """Q-learning loss with added reward noise + half-in bootstrap."""
    o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
    q_tm1 = network.apply(params, o_tm1)
    q_t = network.apply(target_params, o_t)
    r_t += noise_scale * z_t
    batch_q_learning = jax.vmap(rlax.q_learning)
    td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
    return jnp.mean(m_t * td_error**2)

  # Define update function for each member of the ensemble.
  @jax.jit
  def sgd_step(state: TrainingState,
               transitions: Sequence[jnp.ndarray]) -> TrainingState:
    """Does a step of SGD for a single ensemble member over `transitions`."""
    gradients = jax.grad(loss)(state.params, state.target_params, transitions)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

  # Initialize parameters and optimizer state for an ensemble of Q-networks.
  rng = hk.PRNGSequence(seed)
  dummy_obs = np.zeros((1, *obs_spec.shape), jnp.float32)
  initial_params = [
      network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
  ]
  initial_target_params = [
      network.init(next(rng), dummy_obs) for _ in range(num_ensemble)
  ]
  initial_opt_state = [optimizer.init(p) for p in initial_params]

  # Internalize state.
  self._ensemble = [
      TrainingState(p, tp, o, step=0) for p, tp, o in zip(
          initial_params, initial_target_params, initial_opt_state)
  ]
  self._forward = jax.jit(network.apply)
  self._sgd_step = sgd_step
  self._num_ensemble = num_ensemble
  self._optimizer = optimizer
  self._replay = replay.Replay(capacity=replay_capacity)

  # Agent hyperparameters.
  self._num_actions = action_spec.num_values
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._min_replay_size = min_replay_size
  self._epsilon_fn = epsilon_fn
  self._mask_prob = mask_prob

  # Agent state.
  self._active_head = self._ensemble[0]
  self._total_steps = 0
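# The two JAX agents (above and below) assume a `TrainingState` container with
# the fields used in their `sgd_step`. A NamedTuple along these lines would
# satisfy them; the exact definition is not shown in this section, so this is
# an assumption for illustration.
from typing import Any, NamedTuple


class TrainingState(NamedTuple):
  params: Any         # Online-network parameters (hk.Params).
  target_params: Any  # Target-network parameters (hk.Params).
  opt_state: Any      # Optimizer state from `optimizer.init(...)`.
  step: int           # Number of SGD steps taken so far.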
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: hk.Transformed,
    optimizer: optix.InitUpdate,
    batch_size: int,
    epsilon: float,
    rng: hk.PRNGSequence,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
):
  # Define loss function.
  def loss(params: hk.Params, target_params: hk.Params,
           transitions: Sequence[jnp.ndarray]) -> jnp.ndarray:
    """Computes the standard TD(0) Q-learning loss on a batch of transitions."""
    o_tm1, a_tm1, r_t, d_t, o_t = transitions
    q_tm1 = network.apply(params, o_tm1)
    q_t = network.apply(target_params, o_t)
    batch_q_learning = jax.vmap(rlax.q_learning)
    td_error = batch_q_learning(q_tm1, a_tm1, r_t, discount * d_t, q_t)
    return jnp.mean(td_error**2)

  # Define update function.
  @jax.jit
  def sgd_step(state: TrainingState,
               transitions: Sequence[jnp.ndarray]) -> TrainingState:
    """Performs an SGD step on a batch of transitions."""
    gradients = jax.grad(loss)(state.params, state.target_params, transitions)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

  # Initialize the networks and optimizer.
  dummy_observation = np.zeros((1, *obs_spec.shape), jnp.float32)
  initial_params = network.init(next(rng), dummy_observation)
  initial_target_params = network.init(next(rng), dummy_observation)
  initial_opt_state = optimizer.init(initial_params)

  # This carries the agent state relevant to training.
  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      step=0)

  self._sgd_step = sgd_step
  self._forward = jax.jit(network.apply)
  self._replay = replay.Replay(capacity=replay_capacity)

  # Store hyperparameters.
  self._num_actions = action_spec.num_values
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._epsilon = epsilon
  self._total_steps = 0
  self._min_replay_size = min_replay_size
def __init__(
    self,
    obs_spec: dm_env.specs.Array,
    action_spec: dm_env.specs.DiscreteArray,
    online_network: tf.keras.Sequential,
    target_network: tf.keras.Sequential,
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer: tf.keras.optimizers.Optimizer,
    posterior_optimizer: tf.optimizers.Optimizer = tf.optimizers.SGD(
        learning_rate=1e-2),
    seed: int = None,
):
  """A simple DQN agent."""
  # DQN configuration and hyperparameters.
  self._num_features = 32
  self._num_actions = action_spec.num_values
  self._discount = discount
  self._batch_size = batch_size
  self._optimizer = optimizer
  self._posterior_optimizer = posterior_optimizer
  self._total_steps = 0
  self._total_episodes = 0
  self._replay = replay.Replay(capacity=replay_capacity)
  self._min_replay_size = min_replay_size

  # Update periods (reference values from the paper noted inline).
  self._sgd_period = sgd_period  # Paper value: 4.
  self._target_update_period = target_update_period  # Paper value: 10000.
  self._posterior_update_period = 500  # Paper value: 100000.
  self._sample_out_mus_period = 20  # Paper value: 1000.

  # Neural networks for the features.
  self._online_network = online_network
  self._target_network = target_network

  # Normal output distribution, one per action.
  self._target_mus = []
  self._target_mu_covs = []
  self._normal_distros = []
  self._out_mus = []
  eye = tf.eye(self._num_features)
  bijector = tfp.bijectors.FillTriangular(upper=False)
  for idx in range(self._num_actions):
    # Mean is [num_features]; covariance is [num_features, num_features].
    mu = tf.random.normal([self._num_features])
    cov = tf.random.normal(
        [self._num_features, self._num_features], stddev=0.1) + eye
    cov = tf.linalg.band_part(cov, 0, -1)  # Upper triangular.
    cov = 0.5 * (cov + tf.transpose(cov))  # Make it symmetric.
    chol = tf.linalg.cholesky(cov)
    chol = bijector.inverse(chol)
    mu_cov = tf.concat([mu, chol], 0)
    mu_cov = tf.Variable(mu_cov)

    normal_distro = tfp.layers.MultivariateNormalTriL(self._num_features)
    self._target_mus.append(normal_distro(mu_cov).mean())
    self._target_mu_covs.append(mu_cov)
    self._normal_distros.append(normal_distro)
    self._out_mus.append(normal_distro(mu_cov).sample())

  # Set a unified Keras backend float type.
  tf.keras.backend.set_floatx('float32')
def __init__(
    self,
    obs_spec: dm_env.specs.Array,
    action_spec: dm_env.specs.BoundedArray,
    q_network: snt.AbstractModule,
    target_q_network: snt.AbstractModule,
    rho_network: snt.AbstractModule,
    l_network: Sequence[snt.AbstractModule],
    target_l_network: Sequence[snt.AbstractModule],
    batch_size: int,
    discount: float,
    replay_capacity: int,
    min_replay_size: int,
    sgd_period: int,
    target_update_period: int,
    optimizer_primal: tf.train.Optimizer,
    optimizer_dual: tf.train.Optimizer,
    optimizer_l: tf.train.Optimizer,
    learn_iters: int,
    l_approximators: int,
    min_l: float,
    kappa: float,
    eta1: float,
    eta2: float,
    seed: int = None,
):
  """Information seeking learner."""
  # ISL configuration.
  self.q_network = q_network
  self._target_q_network = target_q_network
  self.rho_network = rho_network
  self.l_network = l_network
  self._target_l_network = target_l_network
  self._num_actions = action_spec.maximum - action_spec.minimum + 1
  self._obs_shape = obs_spec.shape
  self._batch_size = batch_size
  self._sgd_period = sgd_period
  self._target_update_period = target_update_period
  self._optimizer_primal = optimizer_primal
  self._optimizer_dual = optimizer_dual
  self._optimizer_l = optimizer_l
  self._min_replay_size = min_replay_size
  # Alternative buffer: ISLReplay(capacity=replay_capacity, average_l=0, mu=0).
  self._replay = replay.Replay(capacity=replay_capacity)
  self._rng = np.random.RandomState(seed)
  tf.set_random_seed(seed)
  self._kappa = kappa
  self._min_l = min_l
  self._eta1 = eta1
  self._eta2 = eta2
  self._learn_iters = learn_iters
  self._l_approximators = l_approximators
  self._total_steps = 0
  self._total_episodes = 0
  self._learn_iter_counter = 0

  # Making the TensorFlow graph.
  o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
  q = q_network(tf.expand_dims(o, 0))
  rho = rho_network(tf.expand_dims(o, 0))
  l = []
  for k in range(self._l_approximators):
    l.append(
        tf.concat([
            l_network[k][a](tf.expand_dims(o, 0))
            for a in range(self._num_actions)
        ], axis=1))

  # Placeholders = (obs, action, reward, discount, next_obs).
  o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
  a_tm1 = tf.placeholder(shape=(None,), dtype=action_spec.dtype)
  r_t = tf.placeholder(shape=(None,), dtype=tf.float32)
  d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
  o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
  chosen_l = tf.placeholder(shape=1, dtype=tf.int32, name='chosen_l_tensor')

  q_tm1 = q_network(o_tm1)
  rho_tm1 = rho_network(o_tm1)
  train_q_value = batched_index(q_tm1, a_tm1)
  train_rho_value = batched_index(rho_tm1, a_tm1)
  train_rho_value_no_grad = tf.stop_gradient(train_rho_value)

  if self._target_update_period > 1:
    q_t = target_q_network(o_t)
  else:
    q_t = q_network(o_t)

  l_tm1_all = tf.stack([
      tf.concat([
          self.l_network[k][a](o_tm1) for a in range(self._num_actions)
      ], axis=1) for k in range(self._l_approximators)
  ], axis=-1)
  l_tm1 = tf.squeeze(tf.gather(l_tm1_all, chosen_l, axis=-1), axis=-1)
  train_l_value = batched_index(l_tm1, a_tm1)

  if self._target_update_period > 1:
    l_online_t_all = tf.stack([
        tf.concat([
            self.l_network[k][a](o_t) for a in range(self._num_actions)
        ], axis=1) for k in range(self._l_approximators)
    ], axis=-1)
    l_online_t = tf.squeeze(
        tf.gather(l_online_t_all, chosen_l, axis=-1), axis=-1)
    l_t_all = tf.stack([
        tf.concat([
            self._target_l_network[k][a](o_t)
            for a in range(self._num_actions)
        ], axis=1) for k in range(self._l_approximators)
    ], axis=-1)
    l_t = tf.squeeze(tf.gather(l_t_all, chosen_l, axis=-1), axis=-1)
    max_ind = tf.math.argmax(l_online_t, axis=1)
  else:
    l_t_all = tf.stack([
        tf.concat([
            self.l_network[k][a](o_t) for a in range(self._num_actions)
        ], axis=1) for k in range(self._l_approximators)
    ], axis=-1)
    l_t = tf.squeeze(tf.gather(l_t_all, chosen_l, axis=-1), axis=-1)
    max_ind = tf.math.argmax(l_t, axis=1)

  soft_max_value = tf.stop_gradient(
      tf.py_function(func=self.soft_max, inp=[q_t, l_t], Tout=tf.float32))
  q_target_value = r_t + discount * d_t * soft_max_value
  delta_primal = train_q_value - q_target_value
  loss_primal = tf.add(
      eta2 * train_rho_value_no_grad * delta_primal,
      (1 - eta2) * 0.5 * tf.square(delta_primal),
      name='loss_q')

  delta_dual = tf.stop_gradient(delta_primal)
  loss_dual = tf.square(delta_dual - train_rho_value, name='loss_rho')

  l_greedy_estimate = tf.add(
      (1 - eta1) * tf.math.abs(delta_primal),
      eta1 * tf.math.abs(train_rho_value_no_grad),
      name='l_greedy_estimate')
  l_target_value = tf.stop_gradient(
      l_greedy_estimate + discount * d_t * batched_index(l_t, max_ind),
      name='l_target')
  loss_l = 0.5 * tf.square(train_l_value - l_target_value)

  train_op_primal = self._optimizer_primal.minimize(loss_primal)
  train_op_dual = self._optimizer_dual.minimize(loss_dual)
  train_op_l = self._optimizer_l.minimize(loss_l)

  # Create target update operations.
  if self._target_update_period > 1:
    target_updates = []
    target_update = update_target_variables(
        target_variables=self._target_q_network.get_all_variables(),
        source_variables=self.q_network.get_all_variables(),
    )
    target_updates.append(target_update)
    for k in range(self._l_approximators):
      for a in range(self._num_actions):
        model = self.l_network[k][a]
        target_model = self._target_l_network[k][a]
        target_update = update_target_variables(
            target_variables=target_model.get_all_variables(),
            source_variables=model.get_all_variables(),
        )
        target_updates.append(target_update)

  # Make session and callables.
  session = tf.Session()
  self._sgd = session.make_callable(
      [train_op_l, train_op_primal, train_op_dual],
      [o_tm1, a_tm1, r_t, d_t, o_t, chosen_l])
  self._q_fn = session.make_callable(q, [o])
  self._rho_fn = session.make_callable(rho, [o])
  self._l_fn = []
  for k in range(self._l_approximators):
    self._l_fn.append(session.make_callable(l[k], [o]))
  if self._target_update_period > 1:
    self._update_target_nets = session.make_callable(target_updates)
  session.run(tf.global_variables_initializer())