def testPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        time_step_tensor_spec = ts.time_step_spec(observation_tensor_spec)

        actor_net = actor_network.ActorNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(time_step_tensor_spec,
                                             action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        # Env will validate action types automatically since we provided the
        # action_spec.
        env = random_py_environment.RandomPyEnvironment(
            observation_spec, action_spec)

        time_step = env.reset()

        for _ in range(100):
            action_step = py_policy.action(time_step)
            time_step = env.step(action_step.action)

    def testBatchedPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        actor_net = actor_network.ActorNetwork(
            self._observation_tensor_spec,
            self._action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(self._time_step_tensor_spec,
                                             self._action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy,
                                                       batch_time_steps=False)

        env_ctr = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
            self._observation_spec, self._action_spec)

        env = batched_py_environment.BatchedPyEnvironment(
            [env_ctr() for _ in range(3)])
        time_step = env.reset()

        for _ in range(20):
            action_step = py_policy.action(time_step)
            time_step = env.step(action_step.action)
Example #3
 def create_actor_network(self, actor_fc_layers,
                          actor_dropout_layer_params):
     return actor_network.ActorNetwork(
         spec.get_observation_spec(),
         spec.get_action_spec(),
         fc_layer_params=actor_fc_layers,
         dropout_layer_params=actor_dropout_layer_params,
         name='actor_' + self.name)
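
A hedged usage sketch for the helper above (the instance name and the layer
settings are illustrative, not taken from the original project):

# Hypothetical call: two hidden layers of 400 and 300 units, each followed by
# 10% dropout; one dropout entry is given per entry in actor_fc_layers.
actor_net = agent_builder.create_actor_network(
    actor_fc_layers=(400, 300),
    actor_dropout_layer_params=[0.1, 0.1])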
Example #4
    def test2DAction(self):
        batch_size = 3
        num_obs_dims = 5
        obs_spec = tensor_spec.TensorSpec([num_obs_dims], tf.float32)
        action_spec = tensor_spec.BoundedTensorSpec([2, 3], tf.float32, 2., 3.)
        actor_net = actor_network.ActorNetwork(obs_spec, action_spec)

        obs = tf.random.uniform([batch_size, num_obs_dims])
        actions, _ = actor_net(obs)
        self.assertAllEqual(actions.shape.as_list(),
                            [batch_size] + action_spec.shape.as_list())
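        # With no fc_layer_params, the network has only the final action
        # projection layer, i.e. one kernel and one bias.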
        self.assertEqual(len(actor_net.trainable_variables), 2)
Example #5
    def testActionsWithinRange(self):
        batch_size = 3
        num_obs_dims = 5
        obs_spec = tensor_spec.TensorSpec([num_obs_dims], tf.float32)
        action_spec = tensor_spec.BoundedTensorSpec([2, 3], tf.float32, 2., 3.)
        actor_net = actor_network.ActorNetwork(obs_spec, action_spec)

        obs = tf.random.uniform([batch_size, num_obs_dims])
        actions, _ = actor_net(obs)
        self.evaluate(tf.compat.v1.global_variables_initializer())
        actions_ = self.evaluate(actions)
        self.assertTrue(np.all(actions_ >= action_spec.minimum))
        self.assertTrue(np.all(actions_ <= action_spec.maximum))
Example #6
    def testAddConvLayers(self):
        batch_size = 3
        num_obs_dims = 5
        obs_spec = tensor_spec.TensorSpec([3, 3, num_obs_dims], tf.float32)
        action_spec = tensor_spec.BoundedTensorSpec([1], tf.float32, 2., 3.)

        actor_net = actor_network.ActorNetwork(obs_spec,
                                               action_spec,
                                               conv_layer_params=[(16, 3, 2)])

        obs = tf.random.uniform([batch_size, 3, 3, num_obs_dims])
        actions, _ = actor_net(obs)
        self.assertAllEqual(actions.shape.as_list(),
                            [batch_size] + action_spec.shape.as_list())
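        # One conv layer plus the final action projection layer: two kernels
        # and two biases in total.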
        self.assertEqual(len(actor_net.trainable_variables), 4)
Example #7
    def get_agent(self, env, params):
        """Returns a TensorFlow SAC-Agent
    
    Arguments:
        env {TFAPyEnvironment} -- Tensorflow-Agents PyEnvironment
        params {ParameterServer} -- ParameterServer from BARK
    
    Returns:
        agent -- tf-agent
    """

        # actor network
        actor_net = actor_network.ActorNetwork(
            env.observation_spec(),
            env.action_spec(),
            fc_layer_params=tuple(
                self._params["ML"]["Agent"]["actor_fc_layer_params"]),
        )

        # critic network
        critic_net = critic_network.CriticNetwork(
            (env.observation_spec(), env.action_spec()),
            observation_fc_layer_params=None,
            action_fc_layer_params=None,
            joint_fc_layer_params=tuple(
                self._params["ML"]["Agent"]["critic_joint_fc_layer_params"]))

        # agent
        # TODO(@hart): put all parameters in config file
        tf_agent = td3_agent.Td3Agent(
            env.time_step_spec(),
            env.action_spec(),
            critic_network=critic_net,
            actor_network=actor_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._params["ML"]["Agent"]
                ["actor_learning_rate"]),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=self._params["ML"]["Agent"]
                ["critic_learning_rate"]),
            debug_summaries=self._params["ML"]["Agent"]["debug_summaries"],
            train_step_counter=self._ckpt.step,
            gamma=0.99,
            target_update_tau=0.5,
            target_policy_noise_clip=0.5)

        tf_agent.initialize()
        return tf_agent
Example #8
        def init_agent():
            """ a DDPG agent is set by default in the application"""
            # get the global step
            global_step = tf.compat.v1.train.get_or_create_global_step()

            # TODO: update this to get the optimizer from tensorflow 2.0 if possible
            optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate)
            time_step_spec = time_step.time_step_spec(
                self._rl_app.observation_spec)
            actor_net = actor_network.ActorNetwork(
                self._rl_app.observation_spec,
                self._rl_app.action_spec,
                fc_layer_params=(400, 300))
            value_net = critic_network.CriticNetwork(
                (time_step_spec.observation, self._rl_app.action_spec),
                observation_fc_layer_params=(400, ),
                action_fc_layer_params=None,
                joint_fc_layer_params=(300, ))
            tf_agent = ddpg_agent.DdpgAgent(
                time_step_spec,
                self._rl_app.action_spec,
                actor_network=actor_net,
                critic_network=value_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=1e-4),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=1e-3),
                ou_stddev=0.2,
                ou_damping=0.15,
                target_update_tau=0.05,
                target_update_period=5,
                dqda_clipping=None,
                td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
                gamma=discount,
                reward_scale_factor=1.0,
                gradient_clipping=gradient_clipping,
                debug_summaries=True,
                summarize_grads_and_vars=True,
                train_step_counter=global_step)
            tf_agent.initialize()
            logger.info("tf_agent initialization is complete")

            # Optimize by wrapping some of the code in a graph using TF function.
            tf_agent.train = common.function(tf_agent.train)

            return tf_agent
Example #9
    def testSavedModel(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
        time_step_spec = ts.time_step_spec(observation_spec)

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        time_step_tensor_spec = tensor_spec.from_spec(time_step_spec)

        actor_net = actor_network.ActorNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(time_step_tensor_spec,
                                             action_tensor_spec,
                                             actor_network=actor_net)

        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(tf_policy)
        saver.save(path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, time_step_spec, action_spec)

        rng = np.random.RandomState()
        sample_time_step = array_spec.sample_spec_nest(time_step_spec, rng)
        batched_sample_time_step = nest_utils.batch_nested_array(
            sample_time_step)

        original_action = tf_policy.action(batched_sample_time_step)
        unbatched_original_action = nest_utils.unbatch_nested_tensors(
            original_action)
        original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                                   unbatched_original_action)
        saved_policy_action = eager_py_policy.action(sample_time_step)

        tf.nest.assert_same_structure(saved_policy_action.action, action_spec)

        np.testing.assert_array_almost_equal(original_action_np.action,
                                             saved_policy_action.action)

    def testPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        actor_net = actor_network.ActorNetwork(
            self._observation_tensor_spec,
            self._action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(self._time_step_tensor_spec,
                                             self._action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy)
        time_step = self._env.reset()

        for _ in range(100):
            action_step = py_policy.action(time_step)
            time_step = self._env.step(action_step.action)
Example #11
  def setUp(self):
    super(SavedModelPYTFEagerPolicyTest, self).setUp()
    if not common.has_eager_been_enabled():
      self.skipTest('Only supported in eager.')

    observation_spec = array_spec.ArraySpec([2], np.float32)
    self.action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
    self.time_step_spec = ts.time_step_spec(observation_spec)

    observation_tensor_spec = tensor_spec.from_spec(observation_spec)
    action_tensor_spec = tensor_spec.from_spec(self.action_spec)
    time_step_tensor_spec = tensor_spec.from_spec(self.time_step_spec)

    actor_net = actor_network.ActorNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=(10,),
    )

    self.tf_policy = actor_policy.ActorPolicy(
        time_step_tensor_spec, action_tensor_spec, actor_network=actor_net)
Example #12
def ACnetworks(environment, hyperparams) -> (actor_network, critic_network):
    observation_spec = environment.observation_spec()
    action_spec = environment.action_spec()

    actor_net = actor_network.ActorNetwork(
        input_tensor_spec=observation_spec,
        output_tensor_spec=action_spec,
        fc_layer_params=hyperparams['actor_fc_layer_params'],
        dropout_layer_params=hyperparams['actor_dropout'],
        activation_fn=tf.nn.relu
    )

    critic_net = critic_network.CriticNetwork(
        input_tensor_spec=(observation_spec, action_spec),
        observation_fc_layer_params=hyperparams['critic_obs_fc_layer_params'],
        action_fc_layer_params=hyperparams['critic_action_fc_layer_params'],
        joint_fc_layer_params=hyperparams['critic_joint_fc_layer_params'],
        joint_dropout_layer_params=hyperparams['critic_joint_dropout'],
        activation_fn=tf.nn.relu
    )

    return (actor_net, critic_net)
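
A hedged sketch of the hyperparams dictionary this helper expects; the keys
come from the function body above, while the values are purely illustrative:

# Illustrative values only; the original project's settings are not shown here.
hyperparams = {
    'actor_fc_layer_params': (400, 300),
    'actor_dropout': None,
    'critic_obs_fc_layer_params': (400,),
    'critic_action_fc_layer_params': None,
    'critic_joint_fc_layer_params': (300,),
    'critic_joint_dropout': None,
}
actor_net, critic_net = ACnetworks(environment, hyperparams)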
Example #13
def DDPG_Bipedal(root_dir):

    # Setting up directories for results
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train', str(run_id))
    eval_dir = os.path.join(root_dir, 'eval', str(run_id))
    vid_dir = os.path.join(root_dir, 'vid', str(run_id))

    # Set up Summary writer for training and evaluation
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        # Metric to record average return
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        # Metric to record average episode length
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    # Create global step
    global_step = tf.compat.v1.train.get_or_create_global_step()

    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Load Environment with different wrappers
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        # Define Actor Network
        actorNN = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=(400, 300),
        )

        # Define Critic Network
        NN_input_specs = (tf_env.time_step_spec().observation,
                          tf_env.action_spec())

        criticNN = critic_network.CriticNetwork(
            NN_input_specs,
            observation_fc_layer_params=(400, ),
            action_fc_layer_params=None,
            joint_fc_layer_params=(300, ),
        )

        # Define & initialize DDPG Agent
        agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actorNN,
            critic_network=criticNN,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=gamma,
            train_step_counter=global_step)
        agent.initialize()

        # Determine which train metrics to display with summary writer
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Set policies for evaluation, initial collection
        eval_policy = agent.policy  # Actor policy
        collect_policy = agent.collect_policy  # Actor policy with OUNoise

        # Set up replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Define driver for initial replay buffer filling
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,  # Initializes with random Parameters
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        # Define collect driver for collect steps per iteration
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            agent.train = common.function(agent.train)

        # Fill the replay buffer with initial_collect_steps steps in tf_env.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        # Computes Evaluation Metrics
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset outputs steps in batches of 64
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=64,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            # Get experience from the dataset (replay buffer).
            experience, _ = next(iterator)
            # Train the agent on that experience.
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()  # Get start time
            # Collect data for replay buffer
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            # Train on experience
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='iterations_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                if results['AverageReturn'].numpy() >= 230.0:
                    video_score = create_video(video_dir=vid_dir,
                                               env_name="BipedalWalker-v2",
                                               vid_policy=eval_policy,
                                               video_id=global_step.numpy())
    return train_loss
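
The function above relies on module-level settings and an external create_video
helper that are not shown in this snippet. A hedged sketch of the assumed
configuration (names taken from the code, values purely illustrative):

# Illustrative module-level configuration; the original values are unknown.
run_id = 0
env_name = 'BipedalWalker-v2'
num_iterations = 100000
initial_collect_steps = 1000
collect_steps_per_iteration = 1
train_steps_per_iteration = 1
replay_buffer_capacity = 100000
num_eval_episodes = 10
eval_interval = 10000
log_interval = 1000
summary_interval = 1000
summaries_flush_secs = 10
use_tf_functions = True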
Example #14
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_update_period=2,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_mujoco.load(env_name))

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        return train_loss
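
Every argument except root_dir has a default, so a hedged invocation might look
like this (the path and iteration count are illustrative):

# Hypothetical call; unspecified arguments fall back to the defaults above.
train_eval('/tmp/td3_halfcheetah', num_iterations=100000)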
Example #15
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=100000,
    exploration_noise_std=0.1,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_update_period=2,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for TD3."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_py_env = suite_mujoco.load(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        exploration_noise_std=exploration_noise_std,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        actor_update_period=actor_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
    )

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()

    train_fn = common.function(tf_agent.train)
    train_op = train_fn(experience=trajectories)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(b/126239733): Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
Example #16
def train_eval():
    # ==========================================================================
    # Setup Logging
    # ==========================================================================
    log_dir = get_log_dir(TrainingParameters.ENV, TrainingParameters.ALGO)
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        log_dir + '/train',
        flush_millis=TrainingParameters.LOG_FLUSH_STEP * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        log_dir + '/eval',
        flush_millis=EvaluationParameters.LOG_FLUSH_STEP * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            buffer_size=EvaluationParameters.NUM_EVAL_EPISODE),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=EvaluationParameters.NUM_EVAL_EPISODE)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
            global_step % TrainingParameters.SUMMARY_INTERVAL, 0)):
        # ======================================================================
        # Create Parallel Environment
        # ======================================================================
        if TrainingParameters.NUM_AGENTS > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: TrainingParameters.ENV_LOAD_FN(TrainingParameters.ENV)]
                    * TrainingParameters.NUM_AGENTS))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(
                TrainingParameters.ENV_LOAD_FN(TrainingParameters.ENV))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            TrainingParameters.ENV_LOAD_FN(TrainingParameters.ENV))

        # ======================================================================
        # Create Actor Network
        # ======================================================================
        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=TrainingParameters.ACTOR_LAYERS,
            activation_fn=TrainingParameters.ACTOR_ACTIVATION)

        # ======================================================================
        # Create Critic Network and Agent
        # ======================================================================

        critic_input_tensor_spec = (tf_env.time_step_spec().observation,
                                    tf_env.action_spec())
        if TrainingParameters.ALGO == 'D4PG':
            critic_net = distributional_critic_network.DistributionalCriticNetwork(
                critic_input_tensor_spec,
                num_atoms=TrainingParameters.NUM_ATOMS,
                observation_fc_layer_params=TrainingParameters.
                CRITIC_OBSERVATION_LAYERS,
                action_fc_layer_params=TrainingParameters.CRITIC_ACTION_LAYERS,
                joint_fc_layer_params=TrainingParameters.CRITIC_JOINT_LAYERS,
                activation_fn=TrainingParameters.CRITIC_ACTIVATION)
            tf_agent = d4pg_agent.D4pgAgent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_network=actor_net,
                critic_network=critic_net,
                min_v=TrainingParameters.V_MIN,
                max_v=TrainingParameters.V_MAX,
                n_step_return=TrainingParameters.N_STEP_RETURN,
                actor_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.ACTOR_LEARNING_RATE),
                actor_l2_lambda=TrainingParameters.ACTOR_L2_LAMBDA,
                critic_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.CRITIC_LEARNING_RATE),
                critic_l2_lambda=TrainingParameters.CRITIC_L2_LAMBDA,
                ou_stddev=TrainingParameters.OU_STDDEV,
                ou_damping=TrainingParameters.OU_DAMPING,
                target_update_tau=TrainingParameters.TAU,
                target_update_period=TrainingParameters.TARGET_UPDATE_PERIOD,
                dqda_clipping=None,
                td_errors_loss_fn=TrainingParameters.TD_ERROR_LOSS_FN,
                gamma=TrainingParameters.DISCOUNT_RATE,
                reward_scale_factor=TrainingParameters.REWARD_SCALE_FACTOR,
                gradient_clipping=None,
                debug_summaries=True,
                summarize_grads_and_vars=True,
                train_step_counter=global_step)

        elif TrainingParameters.ALGO == 'DDPG':
            critic_net = critic_network.CriticNetwork(
                critic_input_tensor_spec,
                observation_fc_layer_params=TrainingParameters.
                CRITIC_OBSERVATION_LAYERS,
                action_fc_layer_params=TrainingParameters.CRITIC_ACTION_LAYERS,
                joint_fc_layer_params=TrainingParameters.CRITIC_JOINT_LAYERS,
            )
            tf_agent = ddpg_agent.DdpgAgent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.ACTOR_LEARNING_RATE),
                critic_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.CRITIC_LEARNING_RATE),
                train_step_counter=global_step,
                gamma=TrainingParameters.DISCOUNT_RATE,
                td_errors_loss_fn=TrainingParameters.TD_ERROR_LOSS_FN,
                target_update_tau=TrainingParameters.TAU,
                target_update_period=TrainingParameters.TARGET_UPDATE_PERIOD)
        else:
            raise ValueError('Received Wrong ALGO in params.py')

        tf_agent.initialize()

        # ======================================================================
        # Setup Replay Buffer, Trajectory Collector, and Training Operator
        # ======================================================================
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(
                batch_size=TrainingParameters.NUM_AGENTS),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=TrainingParameters.NUM_AGENTS),
        ]

        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=TrainingParameters.REPLAY_MEM_SIZE)

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=TrainingParameters.INITIAL_REPLAY_MEM_SIZE)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=TrainingParameters.NUM_DATA_COLLECT)

        # Dataset generates trajectories with shape [B x (N_STEP_RETURN+1) x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=TrainingParameters.BATCH_SIZE,
            num_steps=TrainingParameters.N_STEP_RETURN + 1).prefetch(
                tf.data.experimental.AUTOTUNE)

        dataset_iterator = iter(dataset)

        def train_step():
            experience, _ = next(dataset_iterator)
            return tf_agent.train(experience)

        if TrainingParameters.USE_TF_FUNCTION:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)
            train_step = common.function(train_step)

        # ======================================================================
        # Collecting Data
        # ======================================================================
        initial_collect_driver.run()

        stop_threading_event = threading.Event()
        threads = []
        threads.append(
            threading.Thread(target=run_background,
                             args=(collect_driver.run, stop_threading_event)))

        for thread in threads:
            thread.start()

        # ======================================================================
        # Initial Policy Evaluation
        # ======================================================================
        eval_policy = tf_agent.policy
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=EvaluationParameters.NUM_EVAL_EPISODE,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        # ======================================================================
        # Train the Agent
        # ======================================================================
        timed_at_step = global_step.numpy()
        time_acc = 0
        for _ in range(TrainingParameters.NUM_ITERATION):
            start_time = time.time()
            for _ in range(TrainingParameters.NUM_STEP_PER_ITERATION):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % TrainingParameters.LOG_INTERVAL == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % EvaluationParameters.EVAL_INTERVAL == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=EvaluationParameters.NUM_EVAL_EPISODE,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                policy_saver.PolicySaver(eval_policy).save(
                    log_dir + '/eval/policy_%d' % global_step.numpy())

        stop_threading_event.set()
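
run_background is not defined in this snippet; below is a minimal sketch that is
consistent with how it is used above (repeatedly run the collect driver until
the stop event is set):

def run_background(run_fn, stop_event):
    # Keep collecting experience in the background thread until told to stop.
    while not stop_event.is_set():
        run_fn()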
Example #17
File: train_eval.py  Project: fxia22/agents
def train_eval(
        root_dir,
        gpu=0,
        env_load_fn=None,
        model_ids=None,
        eval_env_mode='headless',
        num_iterations=1000000,
        conv_layer_params=None,
        encoder_fc_layers=[256],
        actor_fc_layers=[400, 300],
        critic_obs_fc_layers=[400],
        critic_action_fc_layers=None,
        critic_joint_fc_layers=[300],
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        eval_only=False,
        eval_deterministic=False,
        num_parallel_environments_eval=1,
        model_ids_eval=None,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=10000,
        rb_checkpoint_interval=50000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if model_ids is None:
            model_ids = [None] * num_parallel_environments
        else:
            assert len(model_ids) == num_parallel_environments, \
                'model ids provided, but length not equal to num_parallel_environments'

        if model_ids_eval is None:
            model_ids_eval = [None] * num_parallel_environments_eval
        else:
            assert len(model_ids_eval) == num_parallel_environments_eval,\
                'model ids eval provided, but length not equal to num_parallel_environments_eval'

        tf_py_env = [
            lambda model_id=model_ids[i]: env_load_fn(model_id, 'headless', gpu
                                                      )
            for i in range(num_parallel_environments)
        ]
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(tf_py_env))

        if eval_env_mode == 'gui':
            assert num_parallel_environments_eval == 1, 'only one GUI env is allowed'
        eval_py_env = [
            lambda model_id=model_ids_eval[i]: env_load_fn(
                model_id, eval_env_mode, gpu)
            for i in range(num_parallel_environments_eval)
        ]
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            eval_py_env)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        print('observation_spec', observation_spec)
        print('action_spec', action_spec)

        glorot_uniform_initializer = tf.compat.v1.keras.initializers.glorot_uniform(
        )
        preprocessing_layers = {
            'depth_seg':
            tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=conv_layer_params,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
            'sensor':
            tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=None,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        actor_net = actor_network.ActorNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            fc_layer_params=actor_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        if eval_deterministic:
            eval_py_policy = py_tf_policy.PyTFPolicy(
                greedy_policy.GreedyPolicy(tf_agent.policy))
        else:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                buffer_size=100, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=100, batch_size=num_parallel_environments),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

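        # Step drivers: seed the buffer with the random policy first, then
        # collect with the agent's exploratory collect policy each iteration.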
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps * num_parallel_environments).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration *
            num_parallel_environments).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=5,
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(
                True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        init_agent_op = tf_agent.initialize()
        with sess.as_default():
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)

            if eval_only:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=0,
                    callback=eval_metrics_callback,
                    tf_summaries=False,
                    log=True,
                )
                episodes = eval_py_env.get_stored_episodes()
                episodes = [
                    episode for sublist in episodes for episode in sublist
                ][:num_eval_episodes]
                metrics = episode_utils.get_metrics(episodes)
                for key in sorted(metrics.keys()):
                    print(key, ':', metrics[key])

                save_path = os.path.join(eval_dir, 'episodes_vis.pkl')
                episode_utils.save(episodes, save_path)
                print('EVAL DONE')
                return

            # Initialize training.
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)
            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=0,
                    callback=eval_metrics_callback,
                    tf_summaries=True,
                    log=True,
                )
                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

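            # Main loop: collect, train, and emit logs, checkpoints and
            # evaluations at their configured intervals.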
            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # print('collect:', time.time() - start_time)

                # train_start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                # print('train:', time.time() - train_start_time)

                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=0,
                        callback=eval_metrics_callback,
                        tf_summaries=True,
                        log=True,
                    )
                    with eval_summary_writer.as_default(), \
                            tf.compat.v2.summary.record_if(True):
                        with tf.name_scope('Metrics/'):
                            episodes = eval_py_env.get_stored_episodes()
                            episodes = [
                                episode for sublist in episodes
                                for episode in sublist
                            ][:num_eval_episodes]
                            metrics = episode_utils.get_metrics(episodes)
                            for key in sorted(metrics.keys()):
                                print(key, ':', metrics[key])
                                metric_op = tf.compat.v2.summary.scalar(
                                    name=key,
                                    data=metrics[key],
                                    step=global_step_val)
                                sess.run(metric_op)
                    sess.run(eval_summary_flush_op)

        sess.close()
Example #18
    target_update_tau = 0.001
    target_update_period = 5
    actor_learning_rate = 0.0002  # @param {type:"number"}
    critic_learning_rate = 0.002
elif env_name == 'BipedalWalker-v2':
    num_iterations = 700000
    fc_layer_params = (400, 300)
    critic_fc_layer_params = (400, 300)
    critic_obs_layer_params = None
    target_update_tau = 0.001
    target_update_period = 5
    actor_learning_rate = 0.000025  # @param {type:"number"}
    critic_learning_rate = 0.00025

actor_net = actor_network.ActorNetwork(train_env.observation_spec(),
                                       train_env.action_spec(),
                                       fc_layer_params=fc_layer_params)

critic_net_input_specs = (train_env.observation_spec(),
                          train_env.action_spec())

critic_net = critic_network.CriticNetwork(
    critic_net_input_specs,
    observation_fc_layer_params=critic_obs_layer_params,
    action_fc_layer_params=None,
    joint_fc_layer_params=critic_fc_layer_params,
)

global_step = tf.compat.v1.train.get_or_create_global_step()
tf_agent = ddpg_agent.DdpgAgent(
    train_env.time_step_spec(),
Example #19
def train_eval():
    """A simple train and eval for DDPG."""
    logdir = FLAGS.logdir
    train_dir = os.path.join(logdir, 'train')
    eval_dir = os.path.join(logdir, 'eval')

    summary_flush_millis = 10 * 1000
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summary_flush_millis)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summary_flush_millis)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=FLAGS.num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=FLAGS.num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % 1000, 0)):
        env = FLAGS.environment
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env))
        eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env))

        if FLAGS.use_rnn:
            actor_net = actor_rnn_network.ActorRnnNetwork(
                tf_env.time_step_spec().observation,
                tf_env.action_spec(),
                input_fc_layer_params=parse_str_flag(FLAGS.actor_fc_layers),
                lstm_size=parse_str_flag(FLAGS.actor_lstm_sizes),
                output_fc_layer_params=parse_str_flag(
                    FLAGS.actor_output_fc_layers))
            critic_net = critic_rnn_network.CriticRnnNetwork(
                input_tensor_spec=(tf_env.time_step_spec().observation,
                                   tf_env.action_spec()),
                observation_fc_layer_params=parse_str_flag(
                    FLAGS.critic_obs_fc_layers),
                action_fc_layer_params=parse_str_flag(
                    FLAGS.critic_action_fc_layers),
                joint_fc_layer_params=parse_str_flag(
                    FLAGS.critic_joint_fc_layers),
                lstm_size=parse_str_flag(FLAGS.critic_lstm_sizes),
                output_fc_layer_params=parse_str_flag(
                    FLAGS.critic_output_fc_layers))
        else:
            actor_net = actor_network.ActorNetwork(
                tf_env.time_step_spec().observation,
                tf_env.action_spec(),
                fc_layer_params=parse_str_flag(FLAGS.actor_fc_layers))
            critic_net = critic_network.CriticNetwork(
                input_tensor_spec=(tf_env.time_step_spec().observation,
                                   tf_env.action_spec()),
                observation_fc_layer_params=parse_str_flag(
                    FLAGS.critic_obs_fc_layers),
                action_fc_layer_params=parse_str_flag(
                    FLAGS.critic_action_fc_layers),
                joint_fc_layer_params=parse_str_flag(
                    FLAGS.critic_joint_fc_layers))

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.critic_learning_rate),
            exploration_noise_std=FLAGS.exploration_noise_std,
            target_update_tau=FLAGS.target_update_tau,
            target_update_period=FLAGS.target_update_period,
            td_errors_loss_fn=None,  # e.g. tf.compat.v1.losses.huber_loss
            gamma=FLAGS.gamma,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=FLAGS.replay_buffer_size)

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=FLAGS.initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=FLAGS.collect_steps_per_iteration)

        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)
        # Collect initial replay data.
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=FLAGS.num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset to generate trajectories.
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=FLAGS.batch_size,
            num_steps=(FLAGS.train_sequence_length + 1)).prefetch(3)
        iterator = iter(dataset)

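        # Each call samples a batch of (train_sequence_length + 1)-step
        # trajectories from the replay buffer and runs one agent update.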
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        for _ in range(FLAGS.num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(FLAGS.train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % 1000 == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % 10000 == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=FLAGS.num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)

        return train_loss
Example #20
def train_eval(
        root_dir,
        offline_dir=None,
        random_seed=None,
        env_name='sawyer_push',
        eval_env_name=None,
        env_load_fn=get_env,
        max_episode_steps=1000,
        eval_episode_steps=1000,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        reset_goal_frequency=1000,  # virtual episode size for reset-free training
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        # reset-free parameters
        use_minimum=True,
        reset_lagrange_learning_rate=3e-4,
        value_threshold=None,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        # Td3 parameters
        actor_update_period=1,
        exploration_noise_std=0.1,
        target_policy_noise=0.1,
        target_policy_noise_clip=0.1,
        dqda_clipping=None,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        # video recording for the environment
        video_record_interval=10000,
        num_videos=0,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):

    start_time = time.time()

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    video_dir = os.path.join(eval_dir, 'videos')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

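        # Pick the gym wrapper that controls when and how the environment is
        # (partially) reset, based on the FLAGS.use_reset_goals mode.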
        if FLAGS.use_reset_goals in [-1]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.GoalTerminalResetWrapper,
                num_success_states=FLAGS.num_success_states,
                full_reset_frequency=max_episode_steps), )
        elif FLAGS.use_reset_goals in [0, 1]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.ResetFreeWrapper,
                reset_goal_frequency=reset_goal_frequency,
                variable_horizon_for_reset=FLAGS.variable_reset_horizon,
                num_success_states=FLAGS.num_success_states,
                full_reset_frequency=max_episode_steps), )
        elif FLAGS.use_reset_goals in [2]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.CustomOracleResetWrapper,
                partial_reset_frequency=reset_goal_frequency,
                episodes_before_full_reset=max_episode_steps //
                reset_goal_frequency), )
        elif FLAGS.use_reset_goals in [3, 4]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.GoalTerminalResetFreeWrapper,
                reset_goal_frequency=reset_goal_frequency,
                num_success_states=FLAGS.num_success_states,
                full_reset_frequency=max_episode_steps), )
        elif FLAGS.use_reset_goals in [5, 7]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.CustomOracleResetGoalTerminalWrapper,
                partial_reset_frequency=reset_goal_frequency,
                episodes_before_full_reset=max_episode_steps //
                reset_goal_frequency), )
        elif FLAGS.use_reset_goals in [6]:
            gym_env_wrappers = (functools.partial(
                reset_free_wrapper.VariableGoalTerminalResetWrapper,
                full_reset_frequency=max_episode_steps), )

        if env_name == 'playpen_reduced':
            train_env_load_fn = functools.partial(
                env_load_fn, reset_at_goal=FLAGS.reset_at_goal)
        else:
            train_env_load_fn = env_load_fn

        env, env_train_metrics, env_eval_metrics, aux_info = train_env_load_fn(
            name=env_name,
            max_episode_steps=None,
            gym_env_wrappers=gym_env_wrappers)

        tf_env = tf_py_environment.TFPyEnvironment(env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(name=eval_env_name,
                        max_episode_steps=eval_episode_steps)[0])

        eval_metrics += env_eval_metrics

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if FLAGS.agent_type == 'sac':
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=functools.partial(
                    tanh_normal_projection_network.TanhNormalProjectionNetwork,
                    std_transform=std_clip_transform))
            critic_net = critic_network.CriticNetwork(
                (observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers,
                kernel_initializer='glorot_uniform',
                last_kernel_initializer='glorot_uniform',
            )

            critic_net_no_entropy = None
            critic_no_entropy_optimizer = None
            if FLAGS.use_no_entropy_q:
                critic_net_no_entropy = critic_network.CriticNetwork(
                    (observation_spec, action_spec),
                    observation_fc_layer_params=critic_obs_fc_layers,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    kernel_initializer='glorot_uniform',
                    last_kernel_initializer='glorot_uniform',
                    name='CriticNetworkNoEntropy1')
                critic_no_entropy_optimizer = tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate)

            tf_agent = SacAgent(
                time_step_spec,
                action_spec,
                num_action_samples=FLAGS.num_action_samples,
                actor_network=actor_net,
                critic_network=critic_net,
                critic_network_no_entropy=critic_net_no_entropy,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                critic_no_entropy_optimizer=critic_no_entropy_optimizer,
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)

        elif FLAGS.agent_type == 'td3':
            actor_net = actor_network.ActorNetwork(
                tf_env.time_step_spec().observation,
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
            )
            critic_net = critic_network.CriticNetwork(
                (observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers,
                kernel_initializer='glorot_uniform',
                last_kernel_initializer='glorot_uniform')

            tf_agent = Td3Agent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                exploration_noise_std=exploration_noise_std,
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                actor_update_period=actor_update_period,
                dqda_clipping=dqda_clipping,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                target_policy_noise=target_policy_noise,
                target_policy_noise_clip=target_policy_noise_clip,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
            )

        tf_agent.initialize()

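        # Reset-free setup: build the generator that proposes reset / practice
        # goals for the wrapped environment.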
        if FLAGS.use_reset_goals > 0:
            if FLAGS.use_reset_goals in [4, 5, 6]:
                reset_goal_generator = ScheduledResetGoal(
                    goal_dim=aux_info['reset_state_shape'][0],
                    num_success_for_switch=FLAGS.num_success_for_switch,
                    num_chunks=FLAGS.num_chunks,
                    name='ScheduledResetGoalGenerator')
            else:
                # distance to initial state distribution
                initial_state_distance = state_distribution_distance.L2Distance(
                    initial_state_shape=aux_info['reset_state_shape'])
                initial_state_distance.update(
                    tf.constant(aux_info['reset_states'], dtype=tf.float32),
                    update_type='complete')

                if use_tf_functions:
                    initial_state_distance.distance = common.function(
                        initial_state_distance.distance)
                    tf_agent.compute_value = common.function(
                        tf_agent.compute_value)

                # initialize reset / practice goal proposer
                if reset_lagrange_learning_rate > 0:
                    reset_goal_generator = ResetGoalGenerator(
                        goal_dim=aux_info['reset_state_shape'][0],
                        compute_value_fn=tf_agent.compute_value,
                        distance_fn=initial_state_distance,
                        use_minimum=use_minimum,
                        value_threshold=value_threshold,
                        lagrange_variable_max=FLAGS.lagrange_max,
                        optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=reset_lagrange_learning_rate),
                        name='reset_goal_generator')
                else:
                    reset_goal_generator = FixedResetGoal(
                        distance_fn=initial_state_distance)

            # if use_tf_functions:
            #   reset_goal_generator.get_reset_goal = common.function(
            #       reset_goal_generator.get_reset_goal)

            # modify the reset-free wrapper to use the reset goal generator
            tf_env.pyenv.envs[0].set_reset_goal_fn(
                reset_goal_generator.get_reset_goal)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        if FLAGS.relabel_goals:
            cur_episode_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                data_spec=tf_agent.collect_data_spec,
                batch_size=1,
                scope='CurEpisodeReplayBuffer',
                max_length=int(2 *
                               min(reset_goal_frequency, max_episode_steps)))

            # NOTE: the replay observer is replaced (rather than extended)
            # because we cannot have both buffers' add_batch observers here.
            replay_observer = [cur_episode_buffer.add_batch]

        # initialize metrics and observers
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        train_metrics += env_train_metrics

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
            tf_agent.policy, use_tf_function=True)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

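        # Checkpointers for the agent + train metrics, the greedy eval policy,
        # and the replay buffer.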
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)
        if use_tf_functions:
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if offline_dir is not None:
            offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                data_spec=tf_agent.collect_data_spec,
                batch_size=1,
                max_length=int(1e5))  # this has to be 100_000
            offline_checkpointer = common.Checkpointer(
                ckpt_dir=offline_dir,
                max_to_keep=1,
                replay_buffer=offline_data)
            offline_checkpointer.initialize_or_restore()

            # set the reset candidates to be all the data in offline buffer
            if ((FLAGS.use_reset_goals > 0
                 and reset_lagrange_learning_rate > 0)
                    or FLAGS.use_reset_goals in [4, 5, 6, 7]):
                tf_env.pyenv.envs[0].set_reset_candidates(
                    nest_utils.unbatch_nested_tensors(
                        offline_data.gather_all()))

        if replay_buffer.num_frames() == 0:
            if offline_dir is not None:
                copy_replay_buffer(offline_data, replay_buffer)
                print(replay_buffer.num_frames())

                # multiply offline data
                if FLAGS.relabel_offline_data:
                    data_multiplier(replay_buffer,
                                    tf_env.pyenv.envs[0].env.compute_reward)
                    print('after data multiplication:',
                          replay_buffer.num_frames())

            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + train_metrics,
                num_steps=1)
            if use_tf_functions:
                initial_collect_driver.run = common.function(
                    initial_collect_driver.run)

            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps with '
                'a random policy.', initial_collect_steps)

            time_step = None
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)

            for iter_idx in range(initial_collect_steps):
                time_step, policy_state = initial_collect_driver.run(
                    time_step=time_step, policy_state=policy_state)

                if time_step.is_last() and FLAGS.relabel_goals:
                    reward_fn = tf_env.pyenv.envs[0].env.compute_reward
                    relabel_function(cur_episode_buffer, time_step, reward_fn,
                                     replay_buffer)
                    cur_episode_buffer.clear()

                if (FLAGS.use_reset_goals > 0 and time_step.is_last()
                        and FLAGS.num_reset_candidates > 0):
                    tf_env.pyenv.envs[0].set_reset_candidates(
                        replay_buffer.get_next(
                            sample_batch_size=FLAGS.num_reset_candidates)[0])

        else:
            time_step = None
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

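        # One gradient step on a sampled batch of 2-step trajectories
        # (boundary transitions are filtered out above).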
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # manual data save for plotting utils
        np_custom_save(os.path.join(eval_dir, 'eval_interval.npy'),
                       eval_interval)
        try:
            average_eval_return = np_custom_load(
                os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
            average_eval_success = np_custom_load(
                os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
            average_eval_final_success = np_custom_load(
                os.path.join(eval_dir,
                             'average_eval_final_success.npy')).tolist()
        except:  # pylint: disable=bare-except
            average_eval_return = []
            average_eval_success = []
            average_eval_final_success = []

        print('initialization_time:', time.time() - start_time)
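        # Reset-free training loop: collect, relabel goals at episode
        # boundaries, periodically refresh the reset candidates and the
        # reset-goal Lagrange multipliers, then take gradient steps.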
        for iter_idx in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )

            if time_step.is_last() and FLAGS.relabel_goals:
                reward_fn = tf_env.pyenv.envs[0].env.compute_reward
                relabel_function(cur_episode_buffer, time_step, reward_fn,
                                 replay_buffer)
                cur_episode_buffer.clear()

            # reset goal generator updates
            if FLAGS.use_reset_goals > 0 and iter_idx % (
                    FLAGS.reset_goal_frequency *
                    collect_steps_per_iteration) == 0:
                if FLAGS.num_reset_candidates > 0:
                    tf_env.pyenv.envs[0].set_reset_candidates(
                        replay_buffer.get_next(
                            sample_batch_size=FLAGS.num_reset_candidates)[0])
                if reset_lagrange_learning_rate > 0:
                    reset_goal_generator.update_lagrange_multipliers()

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                if 'Heatmap' in train_metric.name:
                    if global_step_val % summary_interval == 0:
                        train_metric.tf_summaries(
                            train_step=global_step,
                            step_metrics=train_metrics[:2])
                else:
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:2])

            if (global_step_val % summary_interval == 0
                    and FLAGS.use_reset_goals > 0
                    and reset_lagrange_learning_rate > 0):
                reset_states, values, initial_state_distance_vals, lagrangian = reset_goal_generator.update_summaries(
                    step_counter=global_step)
                for vf_viz_metric in aux_info['value_fn_viz_metrics']:
                    vf_viz_metric.tf_summaries(reset_states,
                                               values,
                                               train_step=global_step,
                                               step_metrics=train_metrics[:2])

                if FLAGS.debug_value_fn_for_reset:
                    num_test_lagrange = 20
                    hyp_lagranges = [
                        1.0 * increment / num_test_lagrange
                        for increment in range(num_test_lagrange + 1)
                    ]

                    door_pos = reset_states[
                        np.argmin(initial_state_distance_vals.numpy() -
                                  lagrangian.numpy() * values.numpy())][3:5]
                    print('cur lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                          (lagrangian.numpy(), door_pos[0], door_pos[1]))
                    for lagrange in hyp_lagranges:
                        door_pos = reset_states[
                            np.argmin(initial_state_distance_vals.numpy() -
                                      lagrange * values.numpy())][3:5]
                        print(
                            'test lagrange: %.2f, cur reset goal: (%.2f, %.2f)'
                            % (lagrange, door_pos[0], door_pos[1]))
                    print('\n')

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

                # numpy saves for plotting
                if 'AverageReturn' in results.keys():
                    average_eval_return.append(
                        results['AverageReturn'].numpy())
                if 'EvalSuccessfulAtAnyStep' in results.keys():
                    average_eval_success.append(
                        results['EvalSuccessfulAtAnyStep'].numpy())
                if 'EvalSuccessfulEpisodes' in results.keys():
                    average_eval_final_success.append(
                        results['EvalSuccessfulEpisodes'].numpy())
                elif 'EvalSuccessfulAtLastStep' in results.keys():
                    average_eval_final_success.append(
                        results['EvalSuccessfulAtLastStep'].numpy())

                if average_eval_return:
                    np_custom_save(
                        os.path.join(eval_dir, 'average_eval_return.npy'),
                        average_eval_return)
                if average_eval_success:
                    np_custom_save(
                        os.path.join(eval_dir, 'average_eval_success.npy'),
                        average_eval_success)
                if average_eval_final_success:
                    np_custom_save(
                        os.path.join(eval_dir,
                                     'average_eval_final_success.npy'),
                        average_eval_final_success)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % video_record_interval == 0:
                for video_idx in range(num_videos):
                    video_name = os.path.join(
                        video_dir, str(global_step_val),
                        'video_' + str(video_idx) + '.mp4')
                    record_video(
                        lambda: env_load_fn(  # pylint: disable=g-long-lambda
                            name=env_name,
                            max_episode_steps=max_episode_steps)[0],
                        video_name,
                        eval_py_policy,
                        max_episode_length=eval_episode_steps)

        return train_loss
Example #21
    env = PulseResponseEnv.PulseResponseEnv(Utarget, rho0s)
    if num_parallel_environments == 1:
        py_env = env
    else:
        py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

with strategy.scope():
    target_update_tau = 0.05
    target_update_period = 5
    ou_stddev = 0.2
    ou_damping = 0.15
    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=[128],
    )
    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=[128],
        action_fc_layer_params=[128],
        joint_fc_layer_params=[128],
    )
    agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
Example #22
def train_eval(
        root_dir,
        num_iterations=100000,
        seed=142,
        actor_fc_layers=(100, 100),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(100, 100),
        n_step_update=5,
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=100000,
        sigma=0.1,
        # Params for target update
        target_update_tau=0.01,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=1024,
        actor_learning_rate=0.5e-4,
        critic_learning_rate=0.5e-4,
        dqda_clipping=None,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        prioritized_replay=False,
        rank_based=False,
        remove_boundaries=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None,
        dir_suffix='-ddpg-5'):
    """A simple train and eval for DDPG."""
    tf.random.set_seed(seed)
    np.random.seed(seed + 1)
    seed_for_env = seed + 2

    root_dir = os.path.expanduser(root_dir)
    train_dir = root_dir + '/train' + dir_suffix
    eval_dir = root_dir + '/eval' + dir_suffix

    global_step = tf.Variable(
        0, name='global_step', dtype=tf.int64, trainable=False)

    # Need to set the seed in the environment; otherwise it uses a
    # non-deterministic generator.
    env_set_seed = GymEnvSeedWrapper(seed_for_env)
    tf_env = tf_py_environment.TFPyEnvironment(
        create_env(gym_env_wrappers=(env_set_seed,)))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        create_env(gym_env_wrappers=(env_set_seed,)))
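
    # Deterministic actor and critic networks for the n-step DDPG variant below.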
    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)

    critic_net = critic_network.CriticNetwork(
        (tf_env.time_step_spec().observation, tf_env.action_spec()),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        output_activation_fn=None)

    tf_agent = ddpg_agent_ex.DdpgAgentEx(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        n_step_update=n_step_update,
        sigma=sigma,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    trainer = rllib_trainer.Trainer(
        train_dir,
        eval_dir,
        tf_agent,
        tf_env,
        eval_tf_env,
        global_step,
        batch_size=batch_size,
        initial_collect_steps=initial_collect_steps,
        collect_steps_per_iteration=collect_steps_per_iteration,
        train_steps_per_iteration=train_steps_per_iteration,
        remove_boundaries=remove_boundaries,
        replay_buffer_capacity=replay_buffer_capacity,
        prioritized_replay=prioritized_replay,
        rank_based=rank_based,
        summaries_flush_secs=summaries_flush_secs,
        summary_interval=summary_interval,
        eval_interval=eval_interval,
        log_interval=log_interval,
        num_eval_episodes=num_eval_episodes,
        use_tf_functions=use_tf_functions)
    result = trainer.train(num_iterations)

    return result
Example #23
    spec_utils.get_tensor_specs(train_env))

#######Networks#####
#conv_layer_params = [(32,3,3),(32,3,3),(32,3,3)]
conv_layer_params = None
fc_layer_params = (400, 300)
kernel_initializer = tf.keras.initializers.VarianceScaling(
    scale=1. / 3., mode='fan_in', distribution='uniform')
final_layer_initializer = tf.keras.initializers.RandomUniform(minval=-0.0003,
                                                              maxval=0.0003)

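# DDPG-style actor/critic pair: VarianceScaling (fan-in) initializers for the
# hidden layers and a small uniform initializer for the final layer.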
actor_net = actor_network.ActorNetwork(
    observation_spec,
    action_spec,
    conv_layer_params=conv_layer_params,
    fc_layer_params=fc_layer_params,
    dropout_layer_params=None,
    activation_fn=tf.keras.activations.relu,
    kernel_initializer=kernel_initializer,
    last_kernel_initializer=final_layer_initializer,
    name='ActorNetwork')

critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec),
    observation_conv_layer_params=conv_layer_params,
    observation_fc_layer_params=(400, ),
    action_fc_layer_params=None,
    joint_fc_layer_params=(300, ),
    kernel_initializer=kernel_initializer,
    last_kernel_initializer=final_layer_initializer)

#target_actor_net = ActorNetwork(observation_spec,
Example #24
def train_eval(
    root_dir,
    random_seed=None,
    env_name='sawyer_push',
    eval_env_name=None,
    env_load_fn=get_env,
    max_episode_steps=1000,
    eval_episode_steps=1000,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    reset_goal_frequency=1000,  # virtual episode size for reset-free training
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    # reset-free parameters
    use_minimum=True,
    reset_lagrange_learning_rate=3e-4,
    value_threshold=None,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    # Td3 parameters
    actor_update_period=1,
    exploration_noise_std=0.1,
    target_policy_noise=0.1,
    target_policy_noise_clip=0.1,
    dqda_clipping=None,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    # video recording for the environment
    video_record_interval=10000,
    num_videos=0,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):

  start_time = time.time()

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  video_dir = os.path.join(eval_dir, 'videos')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    env, env_train_metrics, env_eval_metrics, aux_info = env_load_fn(
        name=env_name,
        max_episode_steps=None,
        gym_env_wrappers=(functools.partial(
            reset_free_wrapper.ResetFreeWrapper,
            reset_goal_frequency=reset_goal_frequency,
            full_reset_frequency=max_episode_steps),))

    tf_env = tf_py_environment.TFPyEnvironment(env)
    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(name=eval_env_name,
                    max_episode_steps=eval_episode_steps)[0])

    eval_metrics += env_eval_metrics

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    if FLAGS.agent_type == 'sac':
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform),
          name='forward_actor')
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
          name='forward_critic')

      tf_agent = SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          name='forward_agent')

      actor_net_rev = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform),
          name='reverse_actor')

      critic_net_rev = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
          name='reverse_critic')

      tf_agent_rev = SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net_rev,
          critic_network=critic_net_rev,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          name='reverse_agent')

    elif FLAGS.agent_type == 'td3':
      actor_net = actor_network.ActorNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
      )
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform')

      tf_agent = Td3Agent(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          exploration_noise_std=exploration_noise_std,
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          actor_update_period=actor_update_period,
          dqda_clipping=dqda_clipping,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          target_policy_noise=target_policy_noise,
          target_policy_noise_clip=target_policy_noise_clip,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
      )

    tf_agent.initialize()
    tf_agent_rev.initialize()

    if FLAGS.use_reset_goals:
      # distance to initial state distribution
      initial_state_distance = state_distribution_distance.L2Distance(
          initial_state_shape=aux_info['reset_state_shape'])
      initial_state_distance.update(
          tf.constant(aux_info['reset_states'], dtype=tf.float32),
          update_type='complete')

      if use_tf_functions:
        initial_state_distance.distance = common.function(
            initial_state_distance.distance)
        tf_agent.compute_value = common.function(tf_agent.compute_value)

      # initialize reset / practice goal proposer
      if reset_lagrange_learning_rate > 0:
        reset_goal_generator = ResetGoalGenerator(
            goal_dim=aux_info['reset_state_shape'][0],
            num_reset_candidates=FLAGS.num_reset_candidates,
            compute_value_fn=tf_agent.compute_value,
            distance_fn=initial_state_distance,
            use_minimum=use_minimum,
            value_threshold=value_threshold,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=reset_lagrange_learning_rate),
            name='reset_goal_generator')
      else:
        reset_goal_generator = FixedResetGoal(
            distance_fn=initial_state_distance)

      # if use_tf_functions:
      #   reset_goal_generator.get_reset_goal = common.function(
      #       reset_goal_generator.get_reset_goal)

      # modify the reset-free wrapper to use the reset goal generator
      tf_env.pyenv.envs[0].set_reset_goal_fn(
          reset_goal_generator.get_reset_goal)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

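    # The reverse (reset) agent gets its own replay buffer, so forward and
    # reverse experience are stored and sampled independently.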
    replay_buffer_rev = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent_rev.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer_rev = [replay_buffer_rev.add_batch]

    # initialize metrics and observers
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]
    train_metrics += env_train_metrics
    train_metrics_rev = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodesRev'),
        tf_metrics.EnvironmentSteps(name='EnvironmentStepsRev'),
        tf_metrics.AverageReturnMetric(
            name='AverageReturnRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]
    train_metrics_rev += aux_info['train_metrics_rev']

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_agent.policy, use_tf_function=True)
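    # Python-side eager policy, used by record_video below to roll out
    # evaluation episodes in the raw (non-TF) environment.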

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_policy_rev = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy
    collect_policy_rev = tf_agent_rev.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward', 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    # reverse policy savers
    train_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'reverse'),
        agent=tf_agent_rev,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics_rev,
                                          'train_metrics_rev'))
    rb_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer_rev'),
        max_to_keep=1,
        replay_buffer=replay_buffer_rev)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    train_checkpointer_rev.initialize_or_restore()
    rb_checkpointer_rev.initialize_or_restore()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)
    collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy_rev,
        observers=replay_observer_rev + train_metrics_rev,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      collect_driver.run = common.function(collect_driver.run)
      collect_driver_rev.run = common.function(collect_driver_rev.run)
      tf_agent.train = common.function(tf_agent.train)
      tf_agent_rev.train = common.function(tf_agent_rev.train)

    if replay_buffer.num_frames() == 0:
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=1)
      initial_collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy_rev,
          observers=replay_observer_rev + train_metrics_rev,
          num_steps=1)
      # Note: wrapping the initial collect drivers in common.function
      # does not work for some reason.
      if use_tf_functions:
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        initial_collect_driver_rev.run = common.function(
            initial_collect_driver_rev.run)

      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps with '
          'a random policy.', initial_collect_steps)
      for iter_idx_initial in range(initial_collect_steps):
        if tf_env.pyenv.envs[0]._forward_or_reset_goal:
          initial_collect_driver.run()
        else:
          initial_collect_driver_rev.run()
        if FLAGS.use_reset_goals and iter_idx_initial % FLAGS.reset_goal_frequency == 0:
          if replay_buffer_rev.num_frames():
            reset_candidates_from_forward_buffer = replay_buffer.get_next(
                sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
            reset_candidates_from_reverse_buffer = replay_buffer_rev.get_next(
                sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
            flat_forward_tensors = tf.nest.flatten(
                reset_candidates_from_forward_buffer)
            flat_reverse_tensors = tf.nest.flatten(
                reset_candidates_from_reverse_buffer)
            concatenated_tensors = [
                tf.concat([x, y], axis=0)
                for x, y in zip(flat_forward_tensors, flat_reverse_tensors)
            ]
            reset_candidates = tf.nest.pack_sequence_as(
                reset_candidates_from_forward_buffer, concatenated_tensors)
            tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)
          else:
            reset_candidates = replay_buffer.get_next(
                sample_batch_size=FLAGS.num_reset_candidates)[0]
            tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
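    # Trajectories are sampled in windows of two consecutive steps; windows
    # whose first step is an episode boundary straddle a reset and are
    # dropped, as they are not valid transitions for TD learning.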
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    dataset_rev = replay_buffer_rev.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator_rev = iter(dataset_rev)

    def train_step_rev():
      experience_rev, _ = next(iterator_rev)
      return tf_agent_rev.train(experience_rev)

    if use_tf_functions:
      train_step = common.function(train_step)
      train_step_rev = common.function(train_step_rev)

    # manual data save for plotting utils
    np_on_cns_save(os.path.join(eval_dir, 'eval_interval.npy'), eval_interval)
    try:
      average_eval_return = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
      average_eval_success = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
    except:  # pylint: disable=bare-except
      # No previous eval data to resume from (e.g., on the first run).
      average_eval_return = []
      average_eval_success = []

    print('initialization_time:', time.time() - start_time)
    for iter_idx in range(num_iterations):
      start_time = time.time()
      if tf_env.pyenv.envs[0]._forward_or_reset_goal:
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )
      else:
        time_step, policy_state = collect_driver_rev.run(
            time_step=time_step,
            policy_state=policy_state,
        )

      # reset goal generator updates
      if FLAGS.use_reset_goals and iter_idx % (
          FLAGS.reset_goal_frequency * collect_steps_per_iteration) == 0:
        reset_candidates_from_forward_buffer = replay_buffer.get_next(
            sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
        reset_candidates_from_reverse_buffer = replay_buffer_rev.get_next(
            sample_batch_size=FLAGS.num_reset_candidates // 2)[0]
        flat_forward_tensors = tf.nest.flatten(
            reset_candidates_from_forward_buffer)
        flat_reverse_tensors = tf.nest.flatten(
            reset_candidates_from_reverse_buffer)
        concatenated_tensors = [
            tf.concat([x, y], axis=0)
            for x, y in zip(flat_forward_tensors, flat_reverse_tensors)
        ]
        reset_candidates = tf.nest.pack_sequence_as(
            reset_candidates_from_forward_buffer, concatenated_tensors)
        tf_env.pyenv.envs[0].set_reset_candidates(reset_candidates)
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator.update_lagrange_multipliers()

      for _ in range(train_steps_per_iteration):
        train_loss_rev = train_step_rev()
        train_loss = train_step()

      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, train_loss.loss)
        logging.info('step = %d, loss_rev = %f', global_step_val,
                     train_loss_rev.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        if 'Heatmap' in train_metric.name:
          if global_step_val % summary_interval == 0:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=train_metrics[:2])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2])

      for train_metric in train_metrics_rev:
        if 'Heatmap' in train_metric.name:
          if global_step_val % summary_interval == 0:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=train_metrics_rev[:2])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics_rev[:2])

      if global_step_val % summary_interval == 0 and FLAGS.use_reset_goals:
        reset_goal_generator.update_summaries(step_counter=global_step)

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

        # numpy saves for plotting
        average_eval_return.append(results['AverageReturn'].numpy())
        average_eval_success.append(results['EvalSuccessfulEpisodes'].numpy())
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_return.npy'),
            average_eval_return)
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_success.npy'),
            average_eval_success)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)
        train_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % video_record_interval == 0:
        for video_idx in range(num_videos):
          video_name = os.path.join(video_dir, str(global_step_val),
                                    'video_' + str(video_idx) + '.mp4')
          record_video(
              lambda: env_load_fn(  # pylint: disable=g-long-lambda
                  name=env_name,
                  max_episode_steps=max_episode_steps)[0],
              video_name,
              eval_py_policy,
              max_episode_length=eval_episode_steps)

    return train_loss
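# The loop above reads several module-level absl flags (FLAGS.agent_type,
# FLAGS.num_action_samples, FLAGS.use_reset_goals,
# FLAGS.num_reset_candidates, FLAGS.reset_goal_frequency). A minimal sketch
# of how they might be declared is shown below; the flag names come from the
# usages above, but the defaults and help strings are illustrative
# assumptions, not from the original source.
#
#   from absl import flags
#
#   flags.DEFINE_enum('agent_type', 'sac', ['sac', 'td3'], 'Agent to train.')
#   flags.DEFINE_integer('num_action_samples', 1,
#                        'Action samples used when computing state values.')
#   flags.DEFINE_boolean('use_reset_goals', True,
#                        'Whether to propose reset/practice goals.')
#   flags.DEFINE_integer('num_reset_candidates', 100,
#                        'Candidate states sampled for reset goals.')
#   flags.DEFINE_integer('reset_goal_frequency', 1000,
#                        'Iterations between refreshes of reset candidates.')
#   FLAGS = flags.FLAGS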
Example #25
def train_eval(
    root_dir,
    env_name='HalfCheetah-v1',
    env_load_fn=suite_mujoco.load,
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    if num_parallel_environments > 1:
      tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * num_parallel_environments))
    else:
      tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_py_env = env_load_fn(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec(),
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    global_step = tf.train.get_or_create_global_step()

    collect_policy = tf_agent.collect_policy()
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = dataset.make_initializable_iterator()
    trajectories, unused_info = iterator.get_next()
    train_op = tf_agent.train(
        experience=trajectories, train_step_counter=global_step)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=tf.contrib.checkpoint.List(train_metrics))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy(),
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])
    summary_op = tf.contrib.summary.all_summary_ops()

    with eval_summary_writer.as_default(), \
         tf.contrib.summary.always_record_summaries():
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      tf.contrib.summary.initialize(session=sess)
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_op, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          tf.logging.info('step = %d, loss = %f', global_step_val,
                          loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          tf.logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
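# Usage sketch (not part of the original example): a typical TF1-style entry
# point for the train_eval above. The flag name `root_dir` and the rest of
# this block are assumptions for illustration.
from absl import app
from absl import flags

flags.DEFINE_string('root_dir', None,
                    'Root directory for checkpoints and summaries.')
FLAGS = flags.FLAGS


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)  # TF1-style logging, as above.
  train_eval(FLAGS.root_dir)


if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)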
Example #26
    dqda_clipping = config.dqda_clipping
    td_errors_loss_fn = config.td_errors_loss_fn
    gamma = config.gamma
    reward_scale_factor = config.reward_scale_factor
    gradient_clipping = config.gradient_clipping

    actor_learning_rate = config.actor_learning_rate
    critic_learning_rate = config.critic_learning_rate
    debug_summaries = config.debug_summaries
    summarize_grads_and_vars = config.summarize_grads_and_vars

    global_step = tf.compat.v1.train.get_or_create_global_step()

    actor_net = actor_network.ActorNetwork(
            train_env.time_step_spec().observation,
            train_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

    critic_net_input_specs = (train_env.time_step_spec().observation,
                            train_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        train_env.time_step_spec(),