def test_push_under_distribute_strategy(
    self, strategy: tf.distribute.Strategy) -> None:
  # Prepare nested variables under strategy scope to push into the server.
  with strategy.scope():
    variables = _create_nested_variable()
    logging.info('Variables: %s', variables)

  # Push the input to the server.
  variable_container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  variable_container.push(variables)  # pytype: disable=wrong-arg-types

  # Check the content of the server.
  self._assert_nested_variable_in_server()

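# Hedged illustration, not part of the original test: a minimal sketch of the
# push/pull round trip the variable container is meant to support. It assumes
# a Reverb variable-container server is already running at `server_address`
# whose default table signature matches this single-variable nest, and that
# `pull` assigns the stored values back into an identically structured nest;
# the helper name and the nest below are hypothetical.
def _example_push_pull_round_trip(server_address: str,
                                  strategy: tf.distribute.Strategy) -> None:
  with strategy.scope():
    source = {'train_step': tf.Variable(7, dtype=tf.int64)}
    target = {'train_step': tf.Variable(0, dtype=tf.int64)}
  container = reverb_variable_container.ReverbVariableContainer(server_address)
  container.push(source)  # Write the nested variables to the server.
  container.pull(target)  # Read them back into `target` in place.
  logging.info('Pulled train_step value: %s', target['train_step'].numpy())

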
def train(
    root_dir: Text,
    environment_name: Text,
    strategy: tf.distribute.Strategy,
    replay_buffer_server_address: Text,
    variable_container_server_address: Text,
    suite_load_fn: Callable[[Text],
                            py_environment.PyEnvironment] = suite_mujoco.load,
    # Training params
    learning_rate: float = 3e-4,
    batch_size: int = 256,
    num_iterations: int = 2000000,
    learner_iterations_per_call: int = 1) -> None:
  """Trains a SAC agent."""
  # Get the specs from the environment.
  logging.info('Training SAC with learning rate: %f', learning_rate)
  env = suite_load_fn(environment_name)
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(env))

  # Create the agent.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = _create_agent(
        train_step=train_step,
        observation_tensor_spec=observation_tensor_spec,
        action_tensor_spec=action_tensor_spec,
        time_step_tensor_spec=time_step_tensor_spec,
        learning_rate=learning_rate)

  # Create the policy saver which saves the initial model now, then it
  # periodically checkpoints the policy weights.
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  save_model_trigger = triggers.PolicySavedModelTrigger(
      saved_model_dir, agent, train_step, interval=1000)

  # Create the variable container.
  variables = {
      reverb_variable_container.POLICY_KEY: agent.collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.push(variables)

  # Create the replay buffer.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      server_address=replay_buffer_server_address)

  # Initialize the dataset.
  def experience_dataset_fn():
    with strategy.scope():
      return reverb_replay.as_dataset(
          sample_batch_size=batch_size, num_steps=2).prefetch(3)

  # Create the learner.
  learning_triggers = [
      save_model_trigger,
      triggers.StepPerSecondLogTrigger(train_step, interval=1000)
  ]
  sac_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)

  # Run the training loop.
  while train_step.numpy() < num_iterations:
    sac_learner.run(iterations=learner_iterations_per_call)
    variable_container.push(variables)
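

# Hedged usage sketch, not from the original module: one way `train` might be
# wired up as a standalone entry point. The root directory, environment name,
# server addresses, iteration count, and the default (single-device) strategy
# below are illustrative assumptions. It also assumes a Reverb server exposing
# both the replay and variable-container tables is already running at those
# addresses, and that a separate collect job is feeding experience into it.
def _example_main() -> None:
  strategy = tf.distribute.get_strategy()  # Default, single-device strategy.
  train(
      root_dir='/tmp/sac_train',
      environment_name='HalfCheetah-v2',
      strategy=strategy,
      replay_buffer_server_address='localhost:8008',
      variable_container_server_address='localhost:8008',
      num_iterations=10000)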