def __init__(
    self,
    obs_spec: dm_env.specs.Array,
    network: snt.RNNCore,
    optimizer: snt.Optimizer,
    sequence_length: int,
    td_lambda: float,
    discount: float,
    seed: int,
):
  """A recurrent actor-critic agent."""

  # Internalise network and optimizer.
  self._forward = tf.function(network)
  self._network = network
  self._optimizer = optimizer

  # Initialise recurrent state.
  self._state: snt.LSTMState = network.initial_state(1)
  self._rollout_initial_state: snt.LSTMState = network.initial_state(1)

  # Set seed and internalise hyperparameters.
  tf.random.set_seed(seed)
  self._sequence_length = sequence_length
  self._num_transitions_in_buffer = 0
  self._discount = discount
  self._td_lambda = td_lambda

  # Initialise rolling experience buffer.
  shapes = [obs_spec.shape, (), (), (), ()]
  dtypes = [obs_spec.dtype, np.int32, np.float32, np.float32, np.float32]
  self._buffer = [
      np.zeros(shape=(self._sequence_length, 1) + shape, dtype=dtype)
      for shape, dtype in zip(shapes, dtypes)
  ]
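# --- Illustrative sketch (not part of the original agent) -------------------
# Assumption: following the TF1 variant below, which uses the same shapes and
# dtypes, the five rolling buffers hold observations, actions, rewards,
# discounts and masks, each shaped [sequence_length, 1, ...]. A hypothetical
# helper showing how one transition could be written at a given index:
def _write_transition(buffer, index, observation, action, reward, discount,
                      mask):
  """Writes one transition into position `index` of each rolling array."""
  for array, value in zip(buffer,
                          (observation, action, reward, discount, mask)):
    array[index, 0] = value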
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    queue: adder.Adder,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    n_step_horizon: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
    verbose_level: Optional[int] = 0,
):
  num_actions = environment_spec.actions.num_values
  self._logger = logger or loggers.TerminalLogger('agent')

  extra_spec = {
      'core_state': network.initial_state(1),
      'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
  }
  # Remove batch dimensions.
  extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
  tf2_utils.create_variables(network, [environment_spec.observations])

  actor = acting.A2CActor(
      environment_spec=environment_spec,
      verbose_level=verbose_level,
      network=network,
      queue=queue)
  learner = learning.A2CLearner(
      environment_spec=environment_spec,
      network=network,
      dataset=queue,
      counter=counter,
      logger=logger,
      discount=discount,
      learning_rate=learning_rate,
      entropy_cost=entropy_cost,
      baseline_cost=baseline_cost,
      max_gradient_norm=max_gradient_norm,
      max_abs_reward=max_abs_reward,
  )

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=0,
      observations_per_step=n_step_horizon)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    target_network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    demonstration_generator: iter,
    demonstration_ratio: float,
    model_directory: str,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    log_to_bigtable: bool = False,
    log_name: str = 'agent',
    checkpoint: bool = True,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
):
  extra_spec = {
      'core_state': network.initial_state(1),
  }
  # Remove batch dimensions.
  extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

  # Replay table.
  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(0.8),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
      signature=adders.SequenceAdder.signature(environment_spec, extra_spec))

  # Demonstration table.
  demonstration_table = reverb.Table(
      name='demonstration_table',
      sampler=reverb.selectors.Prioritized(0.8),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
      signature=adders.SequenceAdder.signature(environment_spec, extra_spec))

  # Launch the Reverb server hosting both tables.
  self._server = reverb.Server([replay_table, demonstration_table],
                               port=None)
  address = f'localhost:{self._server.port}'
  sequence_length = burn_in_length + trace_length + 1

  # Components to add things into replay and into the demonstration table.
  sequence_kwargs = dict(
      period=replay_period,
      sequence_length=sequence_length,
  )
  adder = adders.SequenceAdder(client=reverb.Client(address),
                               **sequence_kwargs)
  priority_function = {demonstration_table.name: lambda x: 1.}
  demo_adder = adders.SequenceAdder(client=reverb.Client(address),
                                    priority_fns=priority_function,
                                    **sequence_kwargs)

  # Play the demonstrations and write them into the demonstration table,
  # exhausting the generator.
  # TODO: MAX REPLAY SIZE
  _prev_action = 1  # This has to come from the spec.
  _add_first = True  # Include this to make the datasets equivalent.
  numpy_state = tf2_utils.to_numpy_squeeze(network.initial_state(1))
  for ts, action in demonstration_generator:
    if _add_first:
      demo_adder.add_first(ts)
      _add_first = False
    else:
      demo_adder.add(_prev_action, ts, extras=(numpy_state,))
    _prev_action = action
    # Reset for the next episode.
    if ts.last():
      _prev_action = None
      _add_first = True

  # Replay dataset.
  max_in_flight_samples_per_worker = 2 * batch_size if batch_size else 100
  dataset = reverb.ReplayDataset.from_table_signature(
      server_address=address,
      table=adders.DEFAULT_PRIORITY_TABLE,
      max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
      # Memory/performance improvement attempt, see
      # https://github.com/deepmind/acme/issues/33.
      num_workers_per_iterator=2,
      sequence_length=sequence_length,
      emit_timesteps=sequence_length is None)

  # Demonstration dataset.
  d_dataset = reverb.ReplayDataset.from_table_signature(
      server_address=address,
      table=demonstration_table.name,
      max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
      num_workers_per_iterator=2,
      sequence_length=sequence_length,
      emit_timesteps=sequence_length is None)

  dataset = tf.data.experimental.sample_from_datasets(
      [dataset, d_dataset],
      [1 - demonstration_ratio, demonstration_ratio])

  # Batch and prefetch.
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      dataset=dataset,
      reverb_client=reverb.TFClient(address),
      counter=counter,
      logger=logger,
      sequence_length=sequence_length,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=False,
  )

  self._checkpointer = tf2_savers.Checkpointer(
      directory=model_directory,
      subdirectory='r2d2_learner_v1',
      time_delta_minutes=15,
      objects_to_save=learner.state,
      enable_checkpointing=checkpoint,
  )

  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save=None,
      time_delta_minutes=15000.,
      directory=model_directory)

  policy_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])

  actor = actors.RecurrentActor(policy_network, adder)
  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=replay_period * max(batch_size, min_replay_size),
      observations_per_step=observations_per_step)
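# --- Illustrative sketch (standalone, not part of the agent above) ----------
# Shows the replay/demonstration mixing done by
# `tf.data.experimental.sample_from_datasets`: elements are drawn from each
# source in proportion to the given weights. The toy datasets and the ratio
# below are assumptions for illustration only.
import tensorflow as tf

demonstration_ratio = 0.25
replay_data = tf.data.Dataset.from_tensor_slices(tf.zeros([1000]))  # stand-in for replay
demo_data = tf.data.Dataset.from_tensor_slices(tf.ones([1000]))     # stand-in for demos

mixed = tf.data.experimental.sample_from_datasets(
    [replay_data, demo_data],
    weights=[1.0 - demonstration_ratio, demonstration_ratio])

# On average ~25% of the sampled elements come from the demonstration source.
samples = tf.stack(list(mixed.take(200)))
print(float(tf.reduce_mean(samples)))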
def __init__(
    self,
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    network: snt.RNNCore,
    optimizer: tf.train.Optimizer,
    sequence_length: int,
    td_lambda: float,
    agent_discount: float,
    seed: int,
):
  """A recurrent actor-critic agent."""
  del action_spec  # unused
  tf.set_random_seed(seed)
  self._sequence_length = sequence_length
  self._num_transitions_in_buffer = 0

  # Create the policy ops.
  obs = tf.placeholder(shape=(1,) + obs_spec.shape, dtype=obs_spec.dtype)
  mask = tf.placeholder(shape=(1,), dtype=tf.float32)
  state = self._placeholders_like(network.initial_state(batch_size=1))
  (online_logits, _), next_state = network((obs, mask), state)
  action = tf.squeeze(tf.multinomial(online_logits, 1, output_dtype=tf.int32))

  # Create placeholders and numpy arrays for learning from trajectories.
  shapes = [obs_spec.shape, (), (), (), ()]
  dtypes = [obs_spec.dtype, np.int32, np.float32, np.float32, np.float32]

  placeholders = [
      tf.placeholder(shape=(self._sequence_length, 1) + shape, dtype=dtype)
      for shape, dtype in zip(shapes, dtypes)]
  observations, actions, rewards, discounts, masks = placeholders

  # Build actor and critic losses.
  (logits, values), final_state = tf.nn.dynamic_rnn(
      network, (observations, tf.expand_dims(masks, -1)),
      initial_state=state, dtype=tf.float32, time_major=True)
  (_, bootstrap_value), _ = network((obs, mask), final_state)
  values, bootstrap_value = tree.map_structure(
      lambda t: tf.squeeze(t, axis=-1), (values, bootstrap_value))
  critic_loss, (advantages, _) = td_lambda_loss(
      state_values=values,
      rewards=rewards,
      pcontinues=agent_discount * discounts,
      bootstrap_value=bootstrap_value,
      lambda_=td_lambda)
  actor_loss = discrete_policy_gradient_loss(logits, actions, advantages)

  # Updates.
  grads_and_vars = optimizer.compute_gradients(actor_loss + critic_loss)
  grads, _ = tf.clip_by_global_norm([g for g, _ in grads_and_vars], 5.)
  grads_and_vars = [(g, pair[1]) for g, pair in zip(grads, grads_and_vars)]
  train_op = optimizer.apply_gradients(grads_and_vars)

  # Create TF session and callables.
  session = tf.Session()
  self._reset_fn = session.make_callable(
      network.initial_state(batch_size=1))
  self._policy_fn = session.make_callable(
      [action, next_state], [obs, mask, state])
  self._update_fn = session.make_callable(
      [train_op, final_state], placeholders + [obs, mask, state])
  session.run(tf.global_variables_initializer())

  # Initialize numpy buffers.
  self.state = self._reset_fn()
  self.update_init_state = self._reset_fn()
  self.arrays = [
      np.zeros(shape=(self._sequence_length, 1) + shape, dtype=dtype)
      for shape, dtype in zip(shapes, dtypes)]
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon_init: float = 1.0,
    epsilon_final: float = 0.01,
    epsilon_schedule_timesteps: float = 20000,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
  if store_lstm_state:
    extra_spec = {
        'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1)),
    }
  else:
    extra_spec = ()

  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
      signature=adders.SequenceAdder.signature(environment_spec, extra_spec))
  self._server = reverb.Server([replay_table], port=None)
  address = f'localhost:{self._server.port}'
  sequence_length = burn_in_length + trace_length + 1

  # Component to add things into replay.
  self._adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=replay_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from.
  dataset = make_reverb_dataset(
      server_address=address,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      sequence_length=sequence_length)

  target_network = copy.deepcopy(network)
  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      sequence_length=sequence_length,
      dataset=dataset,
      reverb_client=reverb.TFClient(address),
      counter=counter,
      logger=logger,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=store_lstm_state,
      max_priority_weight=max_priority_weight,
  )

  self._saver = tf2_savers.Saver(learner.state)

  policy_network = snt.DeepRNN([
      network,
      EpsilonGreedyExploration(
          epsilon_init=epsilon_init,
          epsilon_final=epsilon_final,
          epsilon_schedule_timesteps=epsilon_schedule_timesteps)
  ])
  actor = actors.RecurrentActor(
      policy_network, self._adder, store_recurrent_state=store_lstm_state)

  max_Q_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=0.0).sample(),
  ])
  self._deterministic_actor = actors.RecurrentActor(
      max_Q_network, self._adder, store_recurrent_state=store_lstm_state)

  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=replay_period * max(batch_size, min_replay_size),
      observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    burn_in_length: int,
    trace_length: int,
    replay_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    batch_size: int = 32,
    prefetch_size: int = tf.data.experimental.AUTOTUNE,
    target_update_period: int = 100,
    importance_sampling_exponent: float = 0.2,
    priority_exponent: float = 0.6,
    epsilon: float = 0.01,
    learning_rate: float = 1e-3,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    store_lstm_state: bool = True,
    max_priority_weight: float = 0.9,
    checkpoint: bool = True,
):
  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Prioritized(priority_exponent),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))
  self._server = reverb.Server([replay_table], port=None)
  address = f'localhost:{self._server.port}'
  sequence_length = burn_in_length + trace_length + 1

  # Component to add things into replay.
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=replay_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from.
  reverb_client = reverb.TFClient(address)
  extra_spec = {
      'core_state': network.initial_state(1),
  }
  # Remove batch dimensions.
  extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
  dataset = datasets.make_reverb_dataset(
      client=reverb_client,
      environment_spec=environment_spec,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      extra_spec=extra_spec,
      sequence_length=sequence_length)

  target_network = copy.deepcopy(network)
  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      sequence_length=sequence_length,
      dataset=dataset,
      reverb_client=reverb_client,
      counter=counter,
      logger=logger,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=store_lstm_state,
      max_priority_weight=max_priority_weight,
  )

  self._checkpointer = tf2_savers.Checkpointer(
      subdirectory='r2d2_learner',
      time_delta_minutes=60,
      objects_to_save=learner.state,
      enable_checkpointing=checkpoint,
  )

  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  policy_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])

  actor = actors.RecurrentActor(policy_network, adder)
  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=replay_period * max(batch_size, min_replay_size),
      observations_per_step=observations_per_step)
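# --- Worked example (assumed values, for illustration only) -----------------
# The actor/learner interleaving above is governed by `observations_per_step`
# and `min_observations`. With a hypothetical replay_period of 40 and the
# defaults batch_size=32, samples_per_insert=32.0, min_replay_size=1000:
replay_period, batch_size, samples_per_insert, min_replay_size = 40, 32, 32.0, 1000
observations_per_step = float(replay_period * batch_size) / samples_per_insert
min_observations = replay_period * max(batch_size, min_replay_size)
print(observations_per_step)  # 40.0 environment steps per learner step.
print(min_observations)       # 40000 environment steps before learning starts.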
def __init__(self,
             environment_spec: specs.EnvironmentSpec,
             network: snt.RNNCore,
             target_network: snt.RNNCore,
             burn_in_length: int,
             trace_length: int,
             replay_period: int,
             demonstration_dataset: tf.data.Dataset,
             demonstration_ratio: float,
             counter: counting.Counter = None,
             logger: loggers.Logger = None,
             discount: float = 0.99,
             batch_size: int = 32,
             target_update_period: int = 100,
             importance_sampling_exponent: float = 0.2,
             epsilon: float = 0.01,
             learning_rate: float = 1e-3,
             log_to_bigtable: bool = False,
             log_name: str = 'agent',
             checkpoint: bool = True,
             min_replay_size: int = 1000,
             max_replay_size: int = 1000000,
             samples_per_insert: float = 32.0):
  extra_spec = {
      'core_state': network.initial_state(1),
  }
  # Remove batch dimensions.
  extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)
  replay_table = reverb.Table(
      name=adders.DEFAULT_PRIORITY_TABLE,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      max_size=max_replay_size,
      rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
      signature=adders.SequenceAdder.signature(environment_spec, extra_spec))
  self._server = reverb.Server([replay_table], port=None)
  address = f'localhost:{self._server.port}'
  sequence_length = burn_in_length + trace_length + 1

  # Component to add things into replay.
  sequence_kwargs = dict(
      period=replay_period,
      sequence_length=sequence_length,
  )
  adder = adders.SequenceAdder(client=reverb.Client(address),
                               **sequence_kwargs)

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(server_address=address,
                                         sequence_length=sequence_length)

  # Combine with demonstration dataset.
  transition = functools.partial(_sequence_from_episode,
                                 extra_spec=extra_spec,
                                 **sequence_kwargs)
  dataset_demos = demonstration_dataset.map(transition)
  dataset = tf.data.experimental.sample_from_datasets(
      [dataset, dataset_demos],
      [1 - demonstration_ratio, demonstration_ratio])

  # Batch and prefetch.
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  tf2_utils.create_variables(network, [environment_spec.observations])
  tf2_utils.create_variables(target_network, [environment_spec.observations])

  learner = learning.R2D2Learner(
      environment_spec=environment_spec,
      network=network,
      target_network=target_network,
      burn_in_length=burn_in_length,
      dataset=dataset,
      reverb_client=reverb.TFClient(address),
      counter=counter,
      logger=logger,
      sequence_length=sequence_length,
      discount=discount,
      target_update_period=target_update_period,
      importance_sampling_exponent=importance_sampling_exponent,
      max_replay_size=max_replay_size,
      learning_rate=learning_rate,
      store_lstm_state=False,
  )

  self._checkpointer = tf2_savers.Checkpointer(
      subdirectory='r2d2_learner',
      time_delta_minutes=60,
      objects_to_save=learner.state,
      enable_checkpointing=checkpoint,
  )

  self._snapshotter = tf2_savers.Snapshotter(
      objects_to_save={'network': network}, time_delta_minutes=60.)

  policy_network = snt.DeepRNN([
      network,
      lambda qs: trfl.epsilon_greedy(qs, epsilon=epsilon).sample(),
  ])

  actor = actors.RecurrentActor(policy_network, adder)
  observations_per_step = (
      float(replay_period * batch_size) / samples_per_insert)

  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=replay_period * max(batch_size, min_replay_size),
      observations_per_step=observations_per_step)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: snt.RNNCore,
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    max_abs_reward: Optional[float] = None,
    max_gradient_norm: Optional[float] = None,
):
  num_actions = environment_spec.actions.num_values
  self._logger = logger or loggers.TerminalLogger('agent')

  extra_spec = {
      'core_state': network.initial_state(1),
      'logits': tf.ones(shape=(1, num_actions), dtype=tf.float32)
  }
  # Remove batch dimensions.
  extra_spec = tf2_utils.squeeze_batch_dim(extra_spec)

  queue = reverb.Table.queue(
      name=adders.DEFAULT_PRIORITY_TABLE,
      max_size=max_queue_size,
      signature=adders.SequenceAdder.signature(
          environment_spec,
          extras_spec=extra_spec,
          sequence_length=sequence_length))
  self._server = reverb.Server([queue], port=None)
  self._can_sample = lambda: queue.can_sample(batch_size)
  address = f'localhost:{self._server.port}'

  # Component to add things into replay.
  adder = adders.SequenceAdder(
      client=reverb.Client(address),
      period=sequence_period,
      sequence_length=sequence_length,
  )

  # The dataset object to learn from.
  dataset = datasets.make_reverb_dataset(server_address=address,
                                         batch_size=batch_size)

  tf2_utils.create_variables(network, [environment_spec.observations])

  self._actor = acting.IMPALAActor(network, adder)
  self._learner = learning.IMPALALearner(
      environment_spec=environment_spec,
      network=network,
      dataset=dataset,
      counter=counter,
      logger=logger,
      discount=discount,
      learning_rate=learning_rate,
      entropy_cost=entropy_cost,
      baseline_cost=baseline_cost,
      max_gradient_norm=max_gradient_norm,
      max_abs_reward=max_abs_reward,
  )
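# --- Illustrative sketch (assumptions throughout) ----------------------------
# One plausible way an actor/learner pair like the one constructed above could
# be driven: the actor writes fixed-length sequences into the queue through
# its adder, and the learner steps whenever a full batch can be sampled. The
# `environment`, `actor`, `learner` and `can_sample` objects are assumed.
def run_episode(environment, actor, learner, can_sample):
  timestep = environment.reset()
  actor.observe_first(timestep)
  while not timestep.last():
    action = actor.select_action(timestep.observation)
    timestep = environment.step(action)
    actor.observe(action, next_timestep=timestep)
    # Take learner steps while a full batch of sequences is available.
    while can_sample():
      learner.step()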
def build_learner(agent: snt.RNNCore,
                  agent_state,
                  env_outputs,
                  agent_outputs,
                  reward_clipping: str,
                  discounting: float,
                  baseline_cost: float,
                  entropy_cost: float,
                  policy_cloning_cost: float,
                  value_cloning_cost: float,
                  clip_grad_norm: float,
                  clip_advantage: bool,
                  learning_rate: float,
                  batch_size: int,
                  batch_size_from_replay: int,
                  unroll_length: int,
                  reward_scaling: float = 1.0,
                  adam_beta1: float = 0.9,
                  adam_beta2: float = 0.999,
                  adam_epsilon: float = 1e-8,
                  fixed_step_mul: bool = False,
                  step_mul: int = 8):
  """Builds the learner loop.

  Returns:
    A tuple of (done, infos, num_env_frames_and_train), where evaluating the
    environment-frames tensor also triggers a parameter update.
  """
  learner_outputs, _ = agent.unroll(agent_outputs.action, env_outputs,
                                    agent_state)

  # Use last baseline value (from the value function) to bootstrap.
  bootstrap_value = learner_outputs.baseline[-1]

  # At this point, the environment outputs at time step `t` are the inputs
  # that lead to the learner_outputs at time step `t`. After the following
  # shifting, the actions in agent_outputs and learner_outputs at time step
  # `t` are what lead to the environment outputs at time step `t`.
  agent_outputs = tf.nest.map_structure(lambda t: t[1:], agent_outputs)
  agent_outputs_from_buffer = tf.nest.map_structure(
      lambda t: t[:, :batch_size_from_replay], agent_outputs)
  learner_outputs_from_buffer = tf.nest.map_structure(
      lambda t: t[:-1, :batch_size_from_replay], learner_outputs)
  rewards, infos, done, _ = tf.nest.map_structure(lambda t: t[1:],
                                                  env_outputs)
  learner_outputs = tf.nest.map_structure(lambda t: t[:-1], learner_outputs)

  rewards = rewards * reward_scaling
  clipped_rewards = clip_rewards(rewards, reward_clipping)
  discounts = tf.to_float(~done) * discounting

  # We only need to learn a step_mul policy if the step multiplier is not
  # fixed.
  if not fixed_step_mul:
    agent_outputs.action['step_mul'] = agent_outputs.step_mul
    agent_outputs.action_logits['step_mul'] = agent_outputs.step_mul_logits
    learner_outputs.action_logits[
        'step_mul'] = learner_outputs.step_mul_logits
    agent_outputs_from_buffer.action_logits[
        'step_mul'] = agent_outputs_from_buffer.step_mul_logits
    learner_outputs_from_buffer.action_logits[
        'step_mul'] = learner_outputs_from_buffer.step_mul_logits

  actions = tf.nest.flatten(
      tf.nest.map_structure(lambda x: tf.squeeze(x, axis=2),
                            agent_outputs.action))
  behaviour_logits = tf.nest.flatten(agent_outputs.action_logits)
  target_logits = tf.nest.flatten(learner_outputs.action_logits)
  behaviour_logits_from_buffer = tf.nest.flatten(
      agent_outputs_from_buffer.action_logits)
  target_logits_from_buffer = tf.nest.flatten(
      learner_outputs_from_buffer.action_logits)

  behaviour_neg_log_probs = sum(
      tf.nest.map_structure(compute_neg_log_probs, behaviour_logits, actions))
  target_neg_log_probs = sum(
      tf.nest.map_structure(compute_neg_log_probs, target_logits, actions))
  entropy_loss = sum(
      tf.nest.map_structure(compute_entropy_loss, target_logits))

  with tf.device('/cpu'):
    vtrace_returns = vtrace.from_importance_weights(
        log_rhos=behaviour_neg_log_probs - target_neg_log_probs,
        discounts=discounts,
        rewards=clipped_rewards,
        values=learner_outputs.baseline,
        bootstrap_value=bootstrap_value)

  advantages = tf.stop_gradient(vtrace_returns.pg_advantages)
  # Optionally keep only the positive advantages (zero out the rest).
  if clip_advantage:
    advantages *= tf.where(advantages > 0.0, tf.ones_like(advantages),
                           tf.zeros_like(advantages))

  policy_gradient_loss = tf.reduce_sum(target_neg_log_probs * advantages)
  baseline_loss = .5 * tf.reduce_sum(
      tf.square(vtrace_returns.vs - learner_outputs.baseline))
  entropy_loss = tf.reduce_sum(entropy_loss)

  # Compute the CLEAR policy-cloning and value-cloning losses, as described
  # in https://arxiv.org/abs/1811.11682.
  policy_cloning_loss = sum(
      tf.nest.map_structure(compute_policy_cloning_loss,
                            target_logits_from_buffer,
                            behaviour_logits_from_buffer))
  value_cloning_loss = tf.reduce_sum(
      tf.square(learner_outputs_from_buffer.baseline -
                tf.stop_gradient(agent_outputs_from_buffer.baseline)))

  # Combine the individual losses, weighted by their cost factors, into the
  # overall loss.
  total_loss = (policy_gradient_loss
                + baseline_cost * baseline_loss
                + entropy_cost * entropy_loss
                + policy_cloning_cost * policy_cloning_loss
                + value_cloning_cost * value_cloning_loss)

  optimizer = tf.train.AdamOptimizer(learning_rate, adam_beta1, adam_beta2,
                                     adam_epsilon)
  parameters = tf.trainable_variables()
  gradients = tf.gradients(total_loss, parameters)
  gradients, grad_norm = clip_gradients(gradients, clip_grad_norm)
  train_op = optimizer.apply_gradients(list(zip(gradients, parameters)))

  # Merge updating the network and environment frames into a single tensor.
  with tf.control_dependencies([train_op]):
    if fixed_step_mul:
      step_env_frames = unroll_length * (
          batch_size - batch_size_from_replay) * step_mul
    else:
      # Do not use replay samples to calculate the number of environment
      # frames.
      step_env_frames = tf.to_int64(
          tf.reduce_sum(
              learner_outputs.step_mul[:, batch_size_from_replay:] + 1))
    num_env_frames_and_train = tf.train.get_global_step().assign_add(
        step_env_frames)

  # Adding a few summaries.
  tf.summary.scalar('learning_rate', learning_rate)
  tf.summary.scalar('entropy_cost', entropy_cost)
  tf.summary.scalar('loss/policy_gradient', policy_gradient_loss)
  tf.summary.scalar('loss/baseline', baseline_loss)
  tf.summary.scalar('loss/entropy', entropy_loss)
  tf.summary.scalar('loss/policy_cloning', policy_cloning_loss)
  tf.summary.scalar('loss/value_cloning', value_cloning_loss)
  tf.summary.scalar('loss/total_loss', total_loss)
  for action_name, action in agent_outputs.action.items():
    tf.summary.histogram(f'action/{action_name}', action)
  tf.summary.scalar('grad_norm', grad_norm)

  return done, infos, num_env_frames_and_train
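# --- Illustrative sketch (standalone, made-up numbers) -----------------------
# In `build_learner` above, log_rhos = behaviour_neg_log_probs -
# target_neg_log_probs is the log importance ratio
# log(pi_target(a) / pi_behaviour(a)), since each term is -log p(a).
# A small NumPy check with softmax policies:
import numpy as np

def neg_log_prob(logits, action):
  log_probs = logits - np.log(np.sum(np.exp(logits)))  # log softmax
  return -log_probs[action]

behaviour_logits = np.array([1.0, 2.0, 0.5])
target_logits = np.array([0.5, 2.5, 0.0])
action = 1

log_rho = (neg_log_prob(behaviour_logits, action) -
           neg_log_prob(target_logits, action))
# log_rho == log(pi_target(action) / pi_behaviour(action)).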
def build_critic_learner(agent: snt.RNNCore,
                         agent_state,
                         env_outputs,
                         agent_outputs,
                         reward_clipping: str,
                         discounting: float,
                         clip_grad_norm: float,
                         learning_rate: float,
                         batch_size: int,
                         batch_size_from_replay: int,
                         unroll_length: int,
                         reward_scaling: float = 1.0,
                         adam_beta1: float = 0.9,
                         adam_beta2: float = 0.999,
                         adam_epsilon: float = 1e-8,
                         fixed_step_mul: bool = False,
                         step_mul: int = 8):
  learner_outputs, _ = agent.unroll(agent_outputs.action, env_outputs,
                                    agent_state)
  bootstrap_value = learner_outputs.baseline[-1]
  rewards, infos, done, _ = tf.nest.map_structure(lambda t: t[1:],
                                                  env_outputs)
  learner_outputs = tf.nest.map_structure(lambda t: t[:-1], learner_outputs)

  rewards = rewards * reward_scaling
  clipped_rewards = clip_rewards(rewards, reward_clipping)
  discounts = tf.to_float(~done) * discounting

  # Discounted returns, computed backwards from the bootstrap value.
  returns = tf.scan(
      lambda a, x: x[0] + x[1] * a,
      elems=[clipped_rewards, discounts],
      initializer=bootstrap_value,
      parallel_iterations=1,
      reverse=True,
      back_prop=False)

  baseline_loss = .5 * tf.reduce_sum(
      tf.square(returns - learner_outputs.baseline))

  # Optimization.
  optimizer = tf.train.AdamOptimizer(learning_rate, adam_beta1, adam_beta2,
                                     adam_epsilon)
  parameters = tf.trainable_variables()
  gradients = tf.gradients(baseline_loss, parameters)
  gradients, grad_norm = clip_gradients(gradients, clip_grad_norm)
  train_op = optimizer.apply_gradients(list(zip(gradients, parameters)))

  # Merge updating the network and environment frames into a single tensor.
  with tf.control_dependencies([train_op]):
    if fixed_step_mul:
      step_env_frames = unroll_length * (
          batch_size - batch_size_from_replay) * step_mul
    else:
      # Do not use replay samples to calculate the number of environment
      # frames.
      step_env_frames = tf.to_int64(
          tf.reduce_sum(
              learner_outputs.step_mul[:, batch_size_from_replay:] + 1))
    num_env_frames_and_train = tf.train.get_global_step().assign_add(
        step_env_frames)

  # Adding a few summaries.
  tf.summary.scalar('critic_pretrain/learning_rate', learning_rate,
                    ['critic_pretrain_summaries'])
  tf.summary.scalar('critic_pretrain/baseline_loss', baseline_loss,
                    ['critic_pretrain_summaries'])
  tf.summary.scalar('critic_pretrain/grad_norm', grad_norm,
                    ['critic_pretrain_summaries'])
  summary_op = tf.summary.merge_all('critic_pretrain_summaries')

  return num_env_frames_and_train, summary_op
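# --- Illustrative sketch (standalone, made-up inputs) ------------------------
# The reversed tf.scan in `build_critic_learner` computes discounted returns
# with a bootstrap value: returns[t] = clipped_rewards[t] + discounts[t] *
# returns[t+1], with returns[T] seeded by bootstrap_value. An equivalent
# NumPy reference:
import numpy as np

rewards = np.array([1.0, 0.0, 2.0])
discounts = np.array([0.99, 0.99, 0.0])  # 0.0 where the episode terminated
bootstrap_value = 5.0

returns = np.zeros_like(rewards)
acc = bootstrap_value
for t in reversed(range(len(rewards))):
  acc = rewards[t] + discounts[t] * acc
  returns[t] = acc
# returns == [1.0 + 0.99 * (0.0 + 0.99 * 2.0), 0.0 + 0.99 * 2.0, 2.0]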