Example #1
    def _run(self,
             strategy,
             batch_size=64,
             tf_function=True,
             replay_buffer_max_length=1000,
             train_steps=110,
             log_steps=10):
        """Runs Dqn CartPole environment.

    Args:
      strategy: Strategy to use, None is a valid value.
      batch_size: Total batch size to use for the run.
      tf_function: If True tf.function is used.
      replay_buffer_max_length: Max length of the replay buffer.
      train_steps: Number of steps to run.
      log_steps: How often to log step statistics, e.g. step time.
    """
        obs_spec = array_spec.BoundedArraySpec([
            4,
        ], np.float32, -4., 4.)
        action_spec = array_spec.BoundedArraySpec((), np.int64, 0, 1)

        py_env = random_py_environment.RandomPyEnvironment(
            obs_spec,
            action_spec,
            batch_size=1,
            reward_fn=lambda *_: np.random.randint(1, 10, 1))
        env = tf_py_environment.TFPyEnvironment(py_env)

        policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                                 env.action_spec())

        with distribution_strategy_utils.strategy_scope_context(strategy):
            q_net = q_network.QNetwork(env.time_step_spec().observation,
                                       env.action_spec(),
                                       fc_layer_params=(100, ))

            tf_agent = dqn_agent.DqnAgent(
                env.time_step_spec(),
                env.action_spec(),
                q_network=q_net,
                optimizer=tf.keras.optimizers.Adam(),
                td_errors_loss_fn=common.element_wise_squared_loss)
            tf_agent.initialize()
            print(q_net.summary())

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_max_length)

        driver = dynamic_step_driver.DynamicStepDriver(
            env, policy, [replay_buffer.add_batch])
        if tf_function:
            driver.run = common.function(driver.run)

        for _ in range(replay_buffer_max_length):
            driver.run()

        check_values = ['QNetwork/EncodingNetwork/dense/bias:0']
        initial_values = utils.get_initial_values(tf_agent, check_values)

        with distribution_strategy_utils.strategy_scope_context(strategy):
            dataset = replay_buffer.as_dataset(
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
                sample_batch_size=batch_size,
                num_steps=2)
            if strategy:
                iterator = iter(
                    strategy.experimental_distribute_dataset(dataset))
            else:
                iterator = iter(dataset)

            def train_step(inputs):
                experience, _ = inputs
                return tf_agent.train(experience)

            if tf_function:
                train_step = common.function(train_step)
            self.run_and_report(train_step,
                                strategy,
                                batch_size,
                                train_steps=train_steps,
                                log_steps=log_steps,
                                iterator=iterator)

        utils.check_values_changed(tf_agent, initial_values, check_values)
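
Example #1 hides the distribution details inside benchmark utilities (strategy_scope_context, run_and_report). A stripped-down sketch of the underlying tf.distribute pattern it relies on, with a toy dataset and step function standing in for the replay buffer and the agent's train call (all names below are placeholders, not part of the benchmark), might look like:

import tensorflow as tf

# Default no-op strategy; tf.distribute.MirroredStrategy() could be swapped in.
strategy = tf.distribute.get_strategy()

# Toy dataset standing in for replay_buffer.as_dataset(...).
dataset = tf.data.Dataset.from_tensor_slices(tf.range(8.)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def step_fn(batch):
    # Stands in for the tf_agent.train(experience) call above.
    return tf.reduce_mean(batch)


for batch in dist_dataset:
    per_replica_result = strategy.run(step_fn, args=(batch,))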
Example #2
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        trajectory_spec = trajectory.from_transition(
            time_step=tf_env.time_step_spec(),
            action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
            next_time_step=tf_env.time_step_spec())
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=trajectory_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        tf_agent.initialize()
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy()
        collect_policy = tf_agent.collect_policy()

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        global_step = tf.compat.v1.train.get_or_create_global_step()

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(metrics, global_step.numpy())

        time_step = None
        policy_state = ()

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience,
                                            train_step_counter=global_step)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.contrib.summary.scalar(name='global_steps/sec',
                                          tensor=steps_per_sec)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                metrics = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(metrics, global_step.numpy())
        return train_loss
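
A hypothetical invocation of the train_eval function above (the root directory and the reduced iteration count are illustrative assumptions, not values from the original source):

if __name__ == '__main__':
    # Writes train/ and eval/ summaries under the given root directory.
    train_eval(root_dir='/tmp/dqn_cartpole', num_iterations=20000)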
Example #3
agent = dqn_agent.DqnAgent(train_env.time_step_spec(),
                           train_env.action_spec(),
                           q_network=q_net,
                           optimizer=optimizer,
                           td_errors_loss_fn=common.element_wise_squared_loss,
                           train_step_counter=train_step_counter)

agent.initialize()

# POLICIES

eval_policy = agent.policy
collect_policy = agent.collect_policy

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)

# Metrics and Evaluation


#@test {"skip": true}
def compute_avg_return(environment, policy, num_episodes=10):

    total_return = 0.0
    for _ in range(num_episodes):

        time_step = environment.reset()
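
The snippet is truncated here. The remainder of compute_avg_return, sketched after the standard TF-Agents DQN tutorial this example follows (a reconstruction, not the original continuation), would be roughly:

def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        # Step the environment with the policy until the episode ends.
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    # avg_return is a batched tensor; return its scalar value.
    return avg_return.numpy()[0]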
Example #4
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
Example #5
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(environment.time_step_spec(),
                           environment.action_spec(),
                           q_network=q_net,
                           optimizer=optimizer,
                           td_errors_loss_fn=common.element_wise_squared_loss,
                           train_step_counter=train_step_counter)

agent.initialize()

#Policy

random_policy = random_tf_policy.RandomTFPolicy(environment.time_step_spec(),
                                                environment.action_spec())

time_step = environment.reset()
print(random_policy.action(time_step))

#Replay Buffer
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=batch_size,
    max_length=replay_buffer_max_length)

# print(compute_avg_return(environment, random_policy))

# time_step = environment.reset()
# print(time_step)
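
Example #5 stops after creating the replay buffer. A minimal sketch of the collect/sample/train loop that typically follows this setup (mirroring Examples #2 and #4; environment, agent, replay_buffer and random_policy refer to the objects defined above, while the step counts and sample batch size are placeholder assumptions):

from tf_agents.drivers import dynamic_step_driver

# Prefill the buffer with random experience so sampling num_steps=2 works.
initial_driver = dynamic_step_driver.DynamicStepDriver(
    environment,
    random_policy,
    observers=[replay_buffer.add_batch],
    num_steps=100)
initial_driver.run()

collect_driver = dynamic_step_driver.DynamicStepDriver(
    environment,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=1)

dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                   sample_batch_size=64,
                                   num_steps=2).prefetch(3)
iterator = iter(dataset)

time_step = None
policy_state = agent.collect_policy.get_initial_state(environment.batch_size)
for _ in range(1000):
    time_step, policy_state = collect_driver.run(time_step=time_step,
                                                 policy_state=policy_state)
    experience, _ = next(iterator)
    train_loss = agent.train(experience).loss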
Example #6
    base_dir = os.path.abspath(
        'experiments/env_logs/playpen_reduced/symmetric/')
    env_log_dir = os.path.join(base_dir, 'rc_o/traj1/')
    # env = ResetFreeWrapper(env, reset_goal_frequency=500, full_reset_frequency=max_episode_steps)
    env = GoalTerminalResetWrapper(
        env,
        episodes_before_full_reset=max_episode_steps // 500,
        goal_reset_frequency=500)
    # env = Monitor(env, env_log_dir, video_callable=lambda x: x % 1 == 0, force=True)

    env = wrap_env(env)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    tf_env.render = env.render
    time_step_spec = tf_env.time_step_spec()
    action_spec = tf_env.action_spec()
    policy = random_tf_policy.RandomTFPolicy(action_spec=action_spec,
                                             time_step_spec=time_step_spec)
    collect_data_spec = trajectory.Trajectory(
        step_type=time_step_spec.step_type,
        observation=time_step_spec.observation,
        action=action_spec,
        policy_info=policy.info_spec,
        next_step_type=time_step_spec.step_type,
        reward=time_step_spec.reward,
        discount=time_step_spec.discount)
    offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=collect_data_spec, batch_size=1, max_length=int(1e5))
    rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        env_log_dir, 'replay_buffer'),
                                          max_to_keep=10_000,
                                          replay_buffer=offline_data)
    rb_checkpointer.initialize_or_restore()
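
A short sketch (an illustrative assumption, not part of the original snippet) of reading trajectories back out of the restored offline_data buffer:

# Assumes the checkpoint above restored at least a few trajectories.
dataset = offline_data.as_dataset(sample_batch_size=32, num_steps=2)
for trajectories, _ in dataset.take(1):
    print(trajectories.observation.shape)  # [32, 2, ...]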
Example #7
def train_eval(
        root_dir,
        env_name='MaskedCartPole-v0',
        num_iterations=100000,
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        train_sequence_length=10,
        # Params for collect
        initial_collect_steps=50,
        collect_episodes_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=10,
        batch_size=128,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        q_net = q_rnn_network.QRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        loss_info = common.function(tf_agent.train)(experience=experience)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            logging.info('Collecting initial experience.')
            sess.run(initial_collect_op)

            # Compute evaluation metrics.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(loss_info)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        log=True,
                        callback=eval_metrics_callback,
                    )
Example #8
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=3000000,
    actor_fc_layers=(),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    dual_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None,
    latent_dim=10,
    log_prob_reward_scale=0.0,
    predictor_updates_encoder=False,
    predict_prior=True,
    use_recurrent_actor=False,
    rnn_sequence_length=20,
    clip_max_stddev=10.0,
    clip_min_stddev=0.1,
    clip_mean=30.0,
    predictor_num_layers=2,
    use_identity_encoder=False,
    identity_encoder_single_stddev=False,
    kl_constraint=1.0,
    eval_dropout=(),
    use_residual_predictor=True,
    gym_kwargs=None,
    predict_prior_std=True,
    random_seed=0,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    if use_recurrent_actor:
        batch_size = batch_size // rnn_sequence_length
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        _build_env = functools.partial(
            suite_gym.load,
            environment_name=env_name,  # pylint: disable=invalid-name
            gym_env_wrappers=(),
            gym_kwargs=gym_kwargs)

        tf_env = tf_py_environment.TFPyEnvironment(_build_env())
        eval_vec = []  # (name, env, metrics)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes)
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(_build_env())
        name = ''
        eval_vec.append((name, eval_tf_env, eval_metrics))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        if latent_dim == 'obs':
            latent_dim = observation_spec.shape[0]

        def _activation(t):
            t1, t2 = tf.split(t, 2, axis=1)
            low = -np.inf if clip_mean is None else -clip_mean
            high = np.inf if clip_mean is None else clip_mean
            t1 = rpc_utils.squash_to_range(t1, low, high)

            if clip_min_stddev is None:
                low = -np.inf
            else:
                low = tf.math.log(tf.exp(clip_min_stddev) - 1.0)
            if clip_max_stddev is None:
                high = np.inf
            else:
                high = tf.math.log(tf.exp(clip_max_stddev) - 1.0)
            t2 = rpc_utils.squash_to_range(t2, low, high)
            return tf.concat([t1, t2], axis=1)

        if use_identity_encoder:
            assert latent_dim == observation_spec.shape[0]
            obs_input = tf.keras.layers.Input(observation_spec.shape)
            zeros = 0.0 * obs_input[:, :1]
            stddev_dim = 1 if identity_encoder_single_stddev else latent_dim
            pre_stddev = tf.keras.layers.Dense(stddev_dim,
                                               activation=None)(zeros)
            ones = zeros + tf.ones((1, latent_dim))
            pre_stddev = pre_stddev * ones  # Multiply to broadcast to latent_dim.
            pre_mean_stddev = tf.concat([obs_input, pre_stddev], axis=1)
            output = tfp.layers.IndependentNormal(latent_dim)(pre_mean_stddev)
            encoder_net = tf.keras.Model(inputs=obs_input, outputs=output)
        else:
            encoder_net = tf.keras.Sequential([
                tf.keras.layers.Dense(256, activation='relu'),
                tf.keras.layers.Dense(256, activation='relu'),
                tf.keras.layers.Dense(
                    tfp.layers.IndependentNormal.params_size(latent_dim),
                    activation=_activation,
                    kernel_initializer='glorot_uniform'),
                tfp.layers.IndependentNormal(latent_dim),
            ])

        # Build the predictor net
        obs_input = tf.keras.layers.Input(observation_spec.shape)
        action_input = tf.keras.layers.Input(action_spec.shape)

        class ConstantIndependentNormal(tfp.layers.IndependentNormal):
            """A keras layer that always returns N(0, 1) distribution."""
            def call(self, inputs):
                loc_scale = tf.concat([
                    tf.zeros((latent_dim, )),
                    tf.fill((latent_dim, ), tf.math.log(tf.exp(1.0) - 1))
                ],
                                      axis=0)
                # Multiply by [B x 1] tensor to broadcast batch dimension.
                loc_scale = loc_scale * tf.ones_like(inputs[:, :1])
                return super(ConstantIndependentNormal, self).call(loc_scale)

        if predict_prior:
            z = encoder_net(obs_input)
            if not predictor_updates_encoder:
                z = tf.stop_gradient(z)
            za = tf.concat([z, action_input], axis=1)
            if use_residual_predictor:
                za_input = tf.keras.layers.Input(za.shape[1])
                loc_scale = tf.keras.Sequential(
                    predictor_num_layers *
                    [tf.keras.layers.Dense(256, activation='relu')] + [  # pylint: disable=line-too-long
                        tf.keras.layers.Dense(tfp.layers.IndependentNormal.
                                              params_size(latent_dim),
                                              activation=_activation,
                                              kernel_initializer='zeros'),
                    ])(za_input)
                if predict_prior_std:
                    combined_loc_scale = tf.concat([
                        loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
                        loc_scale[:, latent_dim:]
                    ],
                                                   axis=1)
                else:
                    # Note that softplus(log(e - 1)) = 1.
                    combined_loc_scale = tf.concat([
                        loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
                        tf.math.log(np.e - 1) *
                        tf.ones_like(loc_scale[:, latent_dim:])
                    ],
                                                   axis=1)
                dist = tfp.layers.IndependentNormal(latent_dim)(
                    combined_loc_scale)
                output = tf.keras.Model(inputs=za_input, outputs=dist)(za)
            else:
                assert predict_prior_std
                output = tf.keras.Sequential(
                    predictor_num_layers *
                    [tf.keras.layers.Dense(256, activation='relu')] +  # pylint: disable=line-too-long
                    [
                        tf.keras.layers.Dense(tfp.layers.IndependentNormal.
                                              params_size(latent_dim),
                                              activation=_activation,
                                              kernel_initializer='zeros'),
                        tfp.layers.IndependentNormal(latent_dim),
                    ])(za)
        else:
            # scale is chosen by inverting the softplus function to equal 1.
            if len(obs_input.shape) > 2:
                input_reshaped = tf.reshape(
                    obs_input,
                    [-1, tf.math.reduce_prod(obs_input.shape[1:])])
                #  Multiply by [B x 1] tensor to broadcast batch dimension.
                za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(input_reshaped[:, :1])  # pylint: disable=line-too-long
            else:
                #  Multiply by [B x 1] tensor to broadcast batch dimension.
                za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(obs_input[:, :1])  # pylint: disable=line-too-long
            output = tf.keras.Sequential([
                ConstantIndependentNormal(latent_dim),
            ])(za)
        predictor_net = tf.keras.Model(inputs=(obs_input, action_input),
                                       outputs=output)
        if use_recurrent_actor:
            ActorClass = rpc_utils.RecurrentActorNet  # pylint: disable=invalid-name
        else:
            ActorClass = rpc_utils.ActorNet  # pylint: disable=invalid-name
        actor_net = ActorClass(input_tensor_spec=observation_spec,
                               output_tensor_spec=action_spec,
                               encoder=encoder_net,
                               predictor=predictor_net,
                               fc_layers=actor_fc_layers)

        critic_net = rpc_utils.CriticNet(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')
        critic_net_2 = None
        target_critic_net_1 = None
        target_critic_net_2 = None

        tf_agent = rpc_agent.RpAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            critic_network_2=critic_net_2,
            target_critic_network=target_critic_net_1,
            target_critic_network_2=target_critic_net_2,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        dual_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=dual_learning_rate)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]
        kl_metric = rpc_utils.AverageKLMetric(encoder=encoder_net,
                                              predictor=predictor_net,
                                              batch_size=tf_env.batch_size)
        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        checkpoint_items = {
            'ckpt_dir': train_dir,
            'agent': tf_agent,
            'global_step': global_step,
            'metrics': metric_utils.MetricsGroup(train_metrics,
                                                 'train_metrics'),
            'dual_optimizer': dual_optimizer,
        }
        train_checkpointer = common.Checkpointer(**checkpoint_items)

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps,
            transition_observers=[kl_metric])

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration,
            transition_observers=[kl_metric])

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        for name, eval_tf_env, eval_metrics in eval_vec:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics-%s' % name,
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics, prefix=name)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0
        train_time_acc = 0
        env_time_acc = 0

        if use_recurrent_actor:  # default from sac/train_eval_rnn.py
            num_steps = rnn_sequence_length + 1

            def _filter_invalid_transition(trajectories, unused_arg1):
                return tf.reduce_all(~trajectories.is_boundary()[:-1])

            tf_agent._as_transition = data_converter.AsTransition(  # pylint: disable=protected-access
                tf_agent.data_context,
                squeeze_time_dim=False)
        else:
            num_steps = 2

            def _filter_invalid_transition(trajectories, unused_arg1):
                return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=num_steps).unbatch().filter(_filter_invalid_transition)

        dataset = dataset.batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        @tf.function
        def train_step():
            experience, _ = next(iterator)

            prior = predictor_net(
                (experience.observation[:, 0], experience.action[:, 0]),
                training=False)
            z_next = encoder_net(experience.observation[:, 1], training=False)
            # predictor_kl is a vector of size batch_size.
            predictor_kl = tfp.distributions.kl_divergence(z_next, prior)

            with tf.GradientTape() as tape:
                tape.watch(actor_net._log_kl_coefficient)  # pylint: disable=protected-access
                dual_loss = -1.0 * actor_net._log_kl_coefficient * (  # pylint: disable=protected-access
                    tf.stop_gradient(tf.reduce_mean(predictor_kl)) -
                    kl_constraint)
            dual_grads = tape.gradient(dual_loss,
                                       [actor_net._log_kl_coefficient])  # pylint: disable=protected-access
            grads_and_vars = list(
                zip(dual_grads, [actor_net._log_kl_coefficient]))  # pylint: disable=protected-access
            dual_optimizer.apply_gradients(grads_and_vars)

            # Clip the dual variable so exp(log_kl_coef) <= 1e6.
            log_kl_coef = tf.clip_by_value(
                actor_net._log_kl_coefficient,  # pylint: disable=protected-access
                -1.0 * np.log(1e6),
                np.log(1e6))
            actor_net._log_kl_coefficient.assign(log_kl_coef)  # pylint: disable=protected-access

            with tf.name_scope('dual_loss'):
                tf.compat.v2.summary.scalar(name='dual_loss',
                                            data=tf.reduce_mean(dual_loss),
                                            step=global_step)
                tf.compat.v2.summary.scalar(
                    name='log_kl_coefficient',
                    data=actor_net._log_kl_coefficient,  # pylint: disable=protected-access
                    step=global_step)

            z_entropy = z_next.entropy()
            log_prob = prior.log_prob(z_next.sample())
            with tf.name_scope('rp-metrics'):
                common.generate_tensor_summaries('predictor_kl', predictor_kl,
                                                 global_step)
                common.generate_tensor_summaries('z_entropy', z_entropy,
                                                 global_step)
                common.generate_tensor_summaries('log_prob', log_prob,
                                                 global_step)
                common.generate_tensor_summaries('z_mean', z_next.mean(),
                                                 global_step)
                common.generate_tensor_summaries('z_stddev', z_next.stddev(),
                                                 global_step)
                common.generate_tensor_summaries('prior_mean', prior.mean(),
                                                 global_step)
                common.generate_tensor_summaries('prior_stddev',
                                                 prior.stddev(), global_step)

            if log_prob_reward_scale == 'auto':
                coef = tf.stop_gradient(tf.exp(actor_net._log_kl_coefficient))  # pylint: disable=protected-access
            else:
                coef = log_prob_reward_scale
            tf.debugging.check_numerics(tf.reduce_mean(predictor_kl),
                                        'predictor_kl is inf or nan.')
            tf.debugging.check_numerics(coef, 'coef is inf or nan.')
            new_reward = experience.reward - coef * predictor_kl[:, None]

            experience = experience._replace(reward=new_reward)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Save the hyperparameters
        operative_filename = os.path.join(root_dir, 'operative.gin')
        with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
            f.write(gin.operative_config_str())
            print(gin.operative_config_str())

        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            env_time_acc += time.time() - start_time
            train_start_time = time.time()
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            train_time_acc += time.time() - train_start_time
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                train_steps_per_sec = (global_step_val -
                                       timed_at_step) / train_time_acc
                logging.info('Train: %.3f steps/sec', train_steps_per_sec)
                tf.compat.v2.summary.scalar(name='train_steps_per_sec',
                                            data=train_steps_per_sec,
                                            step=global_step)
                env_steps_per_sec = (global_step_val -
                                     timed_at_step) / env_time_acc
                logging.info('Env: %.3f steps/sec', env_steps_per_sec)
                tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                            data=env_steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0
                train_time_acc = 0
                env_time_acc = 0

            for train_metric in train_metrics + [kl_metric]:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                start_time = time.time()
                for name, eval_tf_env, eval_metrics in eval_vec:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='Metrics-%s' % name,
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics, prefix=name)
                logging.info('Evaluation: %d min',
                             (time.time() - start_time) / 60)
                for prob_dropout in eval_dropout:
                    rpc_utils.eval_dropout_fn(eval_tf_env,
                                              actor_net,
                                              global_step,
                                              prob_dropout=prob_dropout)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
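
The train_step above shapes the reward by subtracting a scaled predictor KL term before handing the experience to tf_agent.train. As a minimal, self-contained sketch of just that broadcasting step (dummy shapes and values, no agent involved):

import tensorflow as tf

# Hypothetical shapes: a [batch, time] reward and one KL estimate per batch entry.
reward = tf.ones([4, 2])
predictor_kl = tf.constant([0.1, 0.2, 0.3, 0.4])
coef = tf.constant(0.5)

# predictor_kl[:, None] has shape [4, 1] and broadcasts across the time axis,
# so every time step of an episode receives the same KL penalty.
new_reward = reward - coef * predictor_kl[:, None]
print(new_reward.shape)  # (4, 2)
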
Exemple #9
0
def train_eval(
    root_dir,
    env_name='sawyer_reach',
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    gamma=0.99,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=200000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    random_seed=0,
    max_future_steps=50,
    actor_std=None,
    log_subset=None,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env, eval_tf_env, obs_dim = c_learning_envs.load(env_name)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if actor_std is None:
            proj_net = tanh_normal_projection_network.TanhNormalProjectionNetwork
        else:
            proj_net = functools.partial(
                tanh_normal_projection_network.TanhNormalProjectionNetwork,
                std_transform=lambda t: actor_std * tf.ones_like(t))

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=proj_net)
        critic_net = c_learning_utils.ClassifierCriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = c_learning_agent.CLearningAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=bce_loss,
            gamma=gamma,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
            c_learning_utils.FinalDistance(buffer_size=num_eval_episodes,
                                           obs_dim=obs_dim),
            c_learning_utils.MinimumDistance(buffer_size=num_eval_episodes,
                                             obs_dim=obs_dim),
            c_learning_utils.DeltaDistance(buffer_size=num_eval_episodes,
                                           obs_dim=obs_dim),
        ]
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
            c_learning_utils.InitialDistance(buffer_size=num_eval_episodes,
                                             batch_size=tf_env.batch_size,
                                             obs_dim=obs_dim),
            c_learning_utils.FinalDistance(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size,
                                           obs_dim=obs_dim),
            c_learning_utils.MinimumDistance(buffer_size=num_eval_episodes,
                                             batch_size=tf_env.batch_size,
                                             obs_dim=obs_dim),
            c_learning_utils.DeltaDistance(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size,
                                           obs_dim=obs_dim),
        ]
        if log_subset is not None:
            start_index, end_index = log_subset
            for name, metrics in [('train', train_metrics),
                                  ('eval', eval_metrics)]:
                metrics.extend([
                    c_learning_utils.InitialDistance(
                        buffer_size=num_eval_episodes,
                        batch_size=tf_env.batch_size
                        if name == 'train' else 10,
                        obs_dim=obs_dim,
                        start_index=start_index,
                        end_index=end_index,
                        name='SubsetInitialDistance'),
                    c_learning_utils.FinalDistance(
                        buffer_size=num_eval_episodes,
                        batch_size=tf_env.batch_size
                        if name == 'train' else 10,
                        obs_dim=obs_dim,
                        start_index=start_index,
                        end_index=end_index,
                        name='SubsetFinalDistance'),
                    c_learning_utils.MinimumDistance(
                        buffer_size=num_eval_episodes,
                        batch_size=tf_env.batch_size
                        if name == 'train' else 10,
                        obs_dim=obs_dim,
                        start_index=start_index,
                        end_index=end_index,
                        name='SubsetMinimumDistance'),
                    c_learning_utils.DeltaDistance(
                        buffer_size=num_eval_episodes,
                        batch_size=tf_env.batch_size
                        if name == 'train' else 10,
                        obs_dim=obs_dim,
                        start_index=start_index,
                        end_index=end_index,
                        name='SubsetDeltaDistance'),
                ])

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=None)

        train_checkpointer.initialize_or_restore()

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Save the hyperparameters
        operative_filename = os.path.join(root_dir, 'operative.gin')
        with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
            f.write(gin.operative_config_str())
            logging.info(gin.operative_config_str())

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                           num_steps=max_future_steps)
        dataset = dataset.unbatch().filter(_filter_invalid_transition)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        goal_fn = functools.partial(c_learning_utils.goal_fn,
                                    batch_size=batch_size,
                                    obs_dim=obs_dim,
                                    gamma=gamma)
        dataset = dataset.map(goal_fn)
        dataset = dataset.prefetch(5)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

        return train_loss
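
The dataset pipeline above samples sequences of length max_future_steps, unbatches them, filters out boundary transitions, re-batches, and only then maps goal_fn. A toy sketch of that unbatch/filter/batch pattern, with plain tensors standing in for trajectories (the data and the validity flag are made up):

import tensorflow as tf

values = tf.constant([1., 2., 3., 4., 5., 6.])
is_valid = tf.constant([True, True, False, True, True, True])  # stand-in for ~is_boundary()

ds = tf.data.Dataset.from_tensor_slices((values, is_valid))
ds = ds.filter(lambda v, valid: valid)   # drop invalid transitions
ds = ds.map(lambda v, valid: v)          # discard the flag
ds = ds.batch(2, drop_remainder=True).prefetch(2)

for batch in ds:
  print(batch.numpy())  # [1. 2.] then [4. 5.]
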
Exemple #10
0
def main(_):
  if FLAGS.eager:
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  _, _, domain_name, _ = FLAGS.env_name.split('-')
  if domain_name in ['cartpole']:
    FLAGS.set_default('action_repeat', 8)
  elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
    FLAGS.set_default('action_repeat', 4)
  elif domain_name in ['finger', 'walker']:
    FLAGS.set_default('action_repeat', 2)

  # Read action_repeat only after the domain-specific defaults above have been
  # applied, so the environments and the step multiplier below use that value.
  action_repeat = FLAGS.action_repeat
  FLAGS.set_default('max_timesteps', FLAGS.max_timesteps // FLAGS.action_repeat)
  env = utils.load_env(
      FLAGS.env_name, FLAGS.seed, action_repeat, FLAGS.frame_stack)
  eval_env = utils.load_env(
      FLAGS.env_name, FLAGS.seed, action_repeat, FLAGS.frame_stack)
  is_image_obs = (isinstance(env.observation_spec(), TensorSpec) and
                  len(env.observation_spec().shape) == 3)

  spec = (
      env.observation_spec(),
      env.action_spec(),
      env.reward_spec(),
      env.reward_spec(),  # discount spec
      env.observation_spec()  # next observation spec
  )

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_length_replay_buffer)

  @tf.function
  def add_to_replay(state, action, reward, discount, next_states):
    replay_buffer.add_batch((state, action, reward, discount, next_states))

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters, seed=FLAGS.seed, env_name=FLAGS.env_name,
      algo_name=FLAGS.algo_name)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  results_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))

  if 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(
        env.observation_spec(),
        env.action_spec(),
        cross_norm='crossnorm' in FLAGS.algo_name)
  elif 'crr' in FLAGS.algo_name:
    model = awr.AWR(
        env.observation_spec(),
        env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(
        env.observation_spec(),
        env.action_spec(), f='exp_mean')
  elif 'sac_v1' in FLAGS.algo_name:
    model = sac_v1.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        cross_norm='crossnorm' in FLAGS.algo_name,
        pcl_actor_update='pc' in FLAGS.algo_name)
  elif 'pcl' in FLAGS.algo_name:
    model = pcl.PCL(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])

  initial_collect_policy = random_tf_policy.RandomTFPolicy(
      env.time_step_spec(), env.action_spec())

  dataset = replay_buffer.as_dataset(
      num_parallel_calls=tf.data.AUTOTUNE,
      sample_batch_size=FLAGS.sample_batch_size)
  if is_image_obs:
    # Augment images as in DRQ.
    dataset = dataset.map(image_aug,
                          num_parallel_calls=tf.data.AUTOTUNE,
                          deterministic=False).prefetch(3)
  else:
    dataset = dataset.prefetch(3)

  def repack(*data):
    return data[0]
  dataset = dataset.map(repack)
  replay_buffer_iter = iter(dataset)

  previous_time = time.time()
  timestep = env.reset()
  episode_return = 0
  episode_timesteps = 0
  step_mult = 1 if action_repeat < 1 else action_repeat

  for i in tqdm.tqdm(range(FLAGS.max_timesteps)):
    if i % FLAGS.deployment_batch_size == 0:
      for _ in range(FLAGS.deployment_batch_size):
        if timestep.is_last():

          if episode_timesteps > 0:
            current_time = time.time()
            with summary_writer.as_default():
              tf.summary.scalar(
                  'train/returns',
                  episode_return,
                  step=(i + 1) * step_mult)
              tf.summary.scalar(
                  'train/FPS',
                  episode_timesteps / (current_time - previous_time),
                  step=(i + 1) * step_mult)

          timestep = env.reset()
          episode_return = 0
          episode_timesteps = 0
          previous_time = time.time()

        if (replay_buffer.num_frames() < FLAGS.num_random_actions or
            replay_buffer.num_frames() < FLAGS.deployment_batch_size):
          # Fall back to random actions until enough frames are collected;
          # the learned policy is used only after the first deployment.
          policy_step = initial_collect_policy.action(timestep)
          action = policy_step.action
        else:
          action = model.actor(timestep.observation, sample=True)

        next_timestep = env.step(action)

        add_to_replay(timestep.observation, action, next_timestep.reward,
                      next_timestep.discount, next_timestep.observation)

        episode_return += next_timestep.reward[0]
        episode_timesteps += 1

        timestep = next_timestep

    if i + 1 >= FLAGS.start_training_timesteps:
      with summary_writer.as_default():
        info_dict = model.update_step(replay_buffer_iter)
      if (i + 1) % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            tf.summary.scalar(f'training/{k}', v, step=(i + 1) * step_mult)

    if (i + 1) % FLAGS.eval_interval == 0:
      logging.info('Performing policy eval.')
      average_returns, evaluation_timesteps = evaluation.evaluate(
          eval_env, model)

      with results_writer.as_default():
        tf.summary.scalar(
            'evaluation/returns', average_returns, step=(i + 1) * step_mult)
        tf.summary.scalar(
            'evaluation/length', evaluation_timesteps, step=(i+1) * step_mult)
      logging.info('Eval at %d: ave returns=%f, ave episode length=%f',
                   (i + 1) * step_mult, average_returns, evaluation_timesteps)

    if (i + 1) % FLAGS.eval_interval == 0:
      model.save_weights(
          os.path.join(FLAGS.save_dir, 'results',
                       FLAGS.env_name + '__' + str(i + 1)))
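
The example above wires a TFUniformReplayBuffer built from raw specs to a tf.function add hook and a dataset iterator. A standalone sketch of that pattern with toy specs (the spec shapes, buffer size, and sample sizes here are illustrative, not the ones used above):

import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

spec = (tf.TensorSpec([3], tf.float32, 'state'),
        tf.TensorSpec([1], tf.float32, 'action'))
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    spec, batch_size=1, max_length=100)

@tf.function
def add_to_replay(state, action):
  replay_buffer.add_batch((state, action))

for _ in range(10):
  add_to_replay(tf.random.normal([1, 3]), tf.random.normal([1, 1]))

dataset = replay_buffer.as_dataset(sample_batch_size=4, num_steps=1)
(state, action), _ = next(iter(dataset))
print(state.shape, action.shape)
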
Exemple #11
0
def train_eval(root_dir,
               env_name="CartPole-v0",
               agent_class=Agent,
               num_iterations=10000,
               initial_collect_steps=1000,
               collect_steps_per_iteration=1,
               epsilon_greedy=0.1,
               replay_buffer_capacity=10000,
               train_steps_per_iteration=1,
               batch_size=32):
    global_step = tf.compat.v1.train.get_or_create_global_step()
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_py_env = suite_gym.load(env_name)
    network = Network(input_tensor_spec=tf_env.time_step_spec().observation,
                      action_spec=tf_env.action_spec())
    tf_agent = agent_class(time_step_spec=tf_env.time_step_spec(),
                           action_spec=tf_env.action_spec(),
                           network=network,
                           optimizer=tf.compat.v1.train.AdamOptimizer(),
                           epsilon_greedy=epsilon_greedy)
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
    replay_observer = [replay_buffer.add_batch]
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps).run()
    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer,
        num_steps=collect_steps_per_iteration).run()
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    train_op = common.function(tf_agent.train)(experience=experience)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
        sess.run(iterator.initializer)
        common.initialize_uninitialized_variables(sess)

        sess.run(init_agent_op)
        sess.run(initial_collect_op)

        global_step_val = sess.run(global_step)

        collect_call = sess.make_callable(collect_op)
        global_step_call = sess.make_callable(global_step)
        train_step_call = sess.make_callable(train_op)

        for _ in range(num_iterations):
            collect_call()
            for _ in range(train_steps_per_iteration):
                loss_info_value, _ = train_step_call()

            global_step_val = global_step_call()
            logging.info("step = %d, loss = %d", global_step_val,
                         loss_info_value)
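
This example trains in TF1 graph mode and turns its collect, train, and global-step ops into session callables. A toy sketch of the sess.make_callable pattern on a single counter op (nothing agent-specific, just the mechanism):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
counter = tf.compat.v1.get_variable('counter', initializer=0)
increment_op = tf.compat.v1.assign_add(counter, 1)

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  # make_callable binds the fetch to the session once, so repeated calls skip
  # the per-run fetch handling that sess.run(increment_op) would redo each time.
  step_call = sess.make_callable(increment_op)
  for _ in range(3):
    print(step_call())  # 1, 2, 3
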
Exemple #12
0
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC."""

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            eval_py_env = env_load_fn(env_name)
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec, max_length=10000)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(tf_env,
                                             initial_collect_policy,
                                             observers=agent_observers +
                                             train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = osp.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = osp.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
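
The loop above guards against exploding losses by counting consecutive diverged steps before giving up. A self-contained sketch of that guard (MAX_LOSS and the step budget are made-up values here):

import math

MAX_LOSS = 1e8
TERMINATE_AFTER_DIVERGED_LOSS_STEPS = 5

def should_stop(losses):
  # Returns True once the loss has been NaN, inf, or above MAX_LOSS for more
  # than TERMINATE_AFTER_DIVERGED_LOSS_STEPS consecutive steps.
  counter = 0
  for loss in losses:
    if math.isnan(loss) or math.isinf(loss) or loss > MAX_LOSS:
      counter += 1
      if counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
        return True
    else:
      counter = 0
  return False

print(should_stop([1.0, float('nan')] + [float('inf')] * 6))  # True
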
Exemple #13
0
    ## Driver and Observer
    # The driver is responsible for collecting trajectories from the environment.
    trainMetrics = [
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric()
    ]

    collectDriver = dynamic_step_driver.DynamicStepDriver(
        trainEnv,
        tfAgent.collect_policy,
        observers=[replayBufferObserver] + trainMetrics,
        num_steps=5048)

    ## Initialize Training Environment
    # Utilizing a random policy
    initialCollectPolicy = random_tf_policy.RandomTFPolicy(
        trainEnv.time_step_spec(), trainEnv.action_spec())

    initDriver = dynamic_step_driver.DynamicStepDriver(
        trainEnv,
        initialCollectPolicy,
        observers=[replayBuffer.add_batch],
        num_steps=int(2e4))

    finalTimeStep, finalPolicyState = initDriver.run()

    dataset = replayBuffer.as_dataset(sample_batch_size=64,
                                      num_steps=2,
                                      num_parallel_calls=3).prefetch(3)

    tfAgent = train_agent(int(1.5e6), tfAgent, trainEnv, collectDriver,
                          trainMetrics, dataset)
Exemple #14
0
def train():
    num_iterations=1000000
    # Params for networks.
    actor_fc_layers=(128, 64)
    actor_output_fc_layers=(64,)
    actor_lstm_size=(32,)
    critic_obs_fc_layers=None
    critic_action_fc_layers=None
    critic_joint_fc_layers=(128,)
    critic_output_fc_layers=(64,)
    critic_lstm_size=(32,)
    num_parallel_environments=1
    # Params for collect
    initial_collect_episodes=1
    collect_episodes_per_iteration=1
    replay_buffer_capacity=1000000
    # Params for target update
    target_update_tau=0.05
    target_update_period=5
    # Params for train
    train_steps_per_iteration=1
    batch_size=256
    critic_learning_rate=3e-4
    train_sequence_length=20
    actor_learning_rate=3e-4
    alpha_learning_rate=3e-4
    td_errors_loss_fn=tf.math.squared_difference
    gamma=0.99
    reward_scale_factor=0.1
    gradient_clipping=None
    use_tf_functions=True
    # Params for eval
    num_eval_episodes=30
    eval_interval=10000

    debug_summaries=False
    summarize_grads_and_vars=False

    global_step = tf.compat.v1.train.get_or_create_global_step()
    
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()
    
    actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers,
        continuous_projection_net=tanh_normal_projection_network
        .TanhNormalProjectionNetwork)

    critic_net = critic_rnn_network.CriticRnnNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')
    
    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()
    
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]
    
    env_steps = tf_metrics.EnvironmentSteps(prefix='Train')

    
    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    collect_policy = tf_agent.collect_policy
    
    
    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + [env_steps],
        num_episodes=initial_collect_episodes)

    # env_steps must observe collection, otherwise episode_steps in the
    # training loop below stays at zero and no train steps are ever run.
    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + [env_steps],
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
        initial_collect_driver.run = common.function(initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)
        
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
        logging.info(
          'Initializing replay buffer by collecting experience for %d episodes '
          'with a random policy.', initial_collect_episodes)
        initial_collect_driver.run()


    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)
    
    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      # Reduce filter_fn over the full sampled sequence: the sequence is kept
      # only if every element except the last one passes the filter, which
      # still allows training on terminal steps.
      return tf.reduce_all(~trajectories.is_boundary()[:-1])
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=train_sequence_length+1).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
        experience, _ = next(iterator)
        return tf_agent.train(experience)

    if use_tf_functions:
        train_step = common.function(train_step)

    for _ in range(num_iterations):
        start_env_steps = env_steps.result()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )
        episode_steps = env_steps.result() - start_env_steps
        # TODO(b/152648849)
        for _ in range(episode_steps):
            for _ in range(train_steps_per_iteration):
                train_step()
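
The _filter_invalid_transition predicate above keeps a sampled sequence only when every step except the last is a non-boundary step, so sequences that end on a terminal step remain usable for training. A tiny sketch of that check on hand-made boundary masks:

import tensorflow as tf

def keep_sequence(is_boundary):
  # is_boundary: [T] boolean vector for one sampled sequence.
  return tf.reduce_all(~is_boundary[:-1])

print(keep_sequence(tf.constant([False, False, True])))   # True: boundary only at the end
print(keep_sequence(tf.constant([False, True, False])))   # False: boundary mid-sequence
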
Exemple #15
0
def train_eval(
    root_dir,
    env_name="HalfCheetah-v2",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=100000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=500000000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    relabel_type=None,
    num_future_states=4,
    max_episode_steps=100,
    random_seed=0,
    eval_task_list=None,
    constant_task=None,  # Whether to train on a single task
    clip_critic=None,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    if relabel_type == "none":
        relabel_type = None
    assert relabel_type in [None, "future", "last", "soft", "random"]
    if constant_task:
        assert relabel_type is None
    if eval_task_list is None:
        eval_task_list = []
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        utils.AverageSuccessMetric(max_episode_steps=max_episode_steps,
                                   buffer_size=num_eval_episodes),
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env, task_distribution = utils.get_env(env_name,
                                                  constant_task=constant_task)
        eval_tf_env, _ = utils.get_env(env_name,
                                       max_episode_steps,
                                       constant_task=constant_task)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=utils.normal_projection_net,
        )
        if isinstance(clip_critic, float):
            output_activation_fn = lambda x: clip_critic * tf.sigmoid(x)
        elif isinstance(clip_critic, tuple):
            assert len(clip_critic) == 2
            min_val, max_val = clip_critic
            output_activation_fn = (
                lambda x:  # pylint: disable=g-long-lambda
                (max_val - min_val) * tf.sigmoid(x) + min_val)
        else:
            output_activation_fn = None
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            output_activation_fn=output_activation_fn,
        )

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = relabelling_replay_buffer.GoalRelabellingReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
            task_distribution=task_distribution,
            actor=actor_net,
            critic=critic_net,
            gamma=gamma,
            relabel_type=relabel_type,
            sample_batch_size=batch_size,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
            num_future_states=num_future_states,
        )

        env_steps = tf_metrics.EnvironmentSteps(prefix="Train")
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix="Train"),
            env_steps,
            utils.AverageSuccessMetric(
                prefix="Train",
                max_episode_steps=max_episode_steps,
                buffer_size=num_eval_episodes,
            ),
            tf_metrics.AverageReturnMetric(
                prefix="Train",
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix="Train",
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )

        train_checkpointer.initialize_or_restore()

        data_collector = utils.DataCollector(
            tf_env,
            tf_agent.collect_policy,
            replay_buffer,
            max_episode_steps=max_episode_steps,
            observers=train_metrics,
        )

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train)
        else:
            tf.config.experimental_run_functions_eagerly(True)

        # Save the config string as late as possible to catch
        # as many object instantiations as possible.
        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        # Collect initial replay data.
        logging.info(
            "Initializing replay buffer by collecting experience for %d steps with "
            "a random policy.",
            initial_collect_steps,
        )
        for _ in range(initial_collect_steps):
            data_collector.step(initial_collect_policy)
        data_collector.reset()
        logging.info("Replay buffer initial size: %d",
                     replay_buffer.num_frames())

        logging.info("Computing initial eval metrics")
        for task in [None] + eval_task_list:
            with utils.FixedTask(eval_tf_env, task):
                prefix = "Metrics" if task is None else "Metrics-%s" % str(
                    task)
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix=prefix,
                )
                metric_utils.log_metrics(eval_metrics)

        time_acc = 0
        env_time_acc = 0
        train_time_acc = 0
        env_steps_before = env_steps.result().numpy()

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train)

        logging.info("Starting training")
        for _ in range(num_iterations):
            start_time = time.time()
            data_collector.step()
            env_time_acc += time.time() - start_time
            train_time_start = time.time()
            for _ in range(train_steps_per_iteration):
                experience = replay_buffer.get_batch()
                train_loss = tf_agent.train(experience)
                total_loss = train_loss.loss
            train_time_acc += time.time() - train_time_start
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step.numpy(),
                             total_loss)

                combined_steps_per_sec = (env_steps.result().numpy() -
                                          env_steps_before) / time_acc
                train_steps_per_sec = (env_steps.result().numpy() -
                                       env_steps_before) / train_time_acc
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / env_time_acc
                logging.info(
                    "%.3f combined steps / sec: %.3f env steps/sec, %.3f train steps/sec",
                    combined_steps_per_sec,
                    env_steps_per_sec,
                    train_steps_per_sec,
                )
                tf.compat.v2.summary.scalar(
                    name="combined_steps_per_sec",
                    data=combined_steps_per_sec,
                    step=env_steps.result(),
                )
                tf.compat.v2.summary.scalar(
                    name="env_steps_per_sec",
                    data=env_steps_per_sec,
                    step=env_steps.result(),
                )
                tf.compat.v2.summary.scalar(
                    name="train_steps_per_sec",
                    data=train_steps_per_sec,
                    step=env_steps.result(),
                )
                time_acc = 0
                env_time_acc = 0
                train_time_acc = 0
                env_steps_before = env_steps.result().numpy()

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:

                for task in [None] + eval_task_list:
                    with utils.FixedTask(eval_tf_env, task):
                        prefix = "Metrics" if task is None else "Metrics-%s" % str(
                            task)
                        logging.info(prefix)
                        metric_utils.eager_compute(
                            eval_metrics,
                            eval_tf_env,
                            eval_policy,
                            num_episodes=num_eval_episodes,
                            train_step=global_step,
                            summary_writer=eval_summary_writer,
                            summary_prefix=prefix,
                        )
                        metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

        return train_loss
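A note on the loop above: it attributes wall-clock time separately to environment stepping and to gradient updates, then resets the accumulators after every log. A minimal standalone sketch of that bookkeeping, where `collect_fn`, `train_fn`, and `log_interval` are placeholder names rather than anything defined in the snippet:

import time

def timed_training_loop(collect_fn, train_fn, num_iterations, log_interval=100):
    """Toy loop illustrating the env-vs-train throughput bookkeeping used above."""
    env_time = train_time = 0.0
    steps_since_log = 0
    for step in range(1, num_iterations + 1):
        t0 = time.time()
        collect_fn()                      # one environment step
        env_time += time.time() - t0
        t1 = time.time()
        train_fn()                        # one gradient update
        train_time += time.time() - t1
        steps_since_log += 1
        if step % log_interval == 0:
            print('%.1f combined steps/sec (%.1f env, %.1f train)' %
                  (steps_since_log / max(env_time + train_time, 1e-9),
                   steps_since_log / max(env_time, 1e-9),
                   steps_since_log / max(train_time, 1e-9)))
            env_time = train_time = 0.0   # reset accumulators, as in the loop above
            steps_since_log = 0
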
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # one of 'sac', 'latent_sac', 'dqn', 'ddpg', 'td3'
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # latent model type: 'hierarchical' or 'non-hierarchical'
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40,),  # lstm size for actor
    critic_lstm_size=(40,),  # lstm size for critic
    actor_output_fc_layers=(100,),  # fc layers after the actor LSTM
    critic_output_fc_layers=(100,),  # fc layers after the critic LSTM
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration parameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration parameter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1):  # number of times each action is repeated in the environment
  # Setup GPU
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpu_allow_growth:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  if gpu_memory_limit:
    for gpu in gpus:
      tf.config.experimental.set_virtual_device_configuration(
          gpu,
          [tf.config.experimental.VirtualDeviceConfiguration(
              memory_limit=gpu_memory_limit)])

  # Build the output directory for this run
  root_dir = os.path.expanduser(root_dir)
  root_dir = os.path.join(root_dir, env_name, experiment_name)

  # Get summary writers
  summary_writer = tf.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  # Eval metrics
  eval_metrics = [
      tf_metrics.AverageReturnMetric(
        name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(
        name='AverageEpisodeLengthEvalPolicy',
        buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Record summaries only every summary_interval steps
  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create Carla environment
    if agent_name == 'latent_sac':
      py_env, eval_py_env = load_carla_env(
          env_name='carla-v0',
          obs_channels=input_names + mask_names,
          action_repeat=action_repeat)
    elif agent_name == 'dqn':
      py_env, eval_py_env = load_carla_env(
          env_name='carla-v0',
          discrete=True,
          obs_channels=input_names,
          action_repeat=action_repeat)
    else:
      py_env, eval_py_env = load_carla_env(
          env_name='carla-v0',
          obs_channels=input_names,
          action_repeat=action_repeat)

    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
    fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

    # Specs
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    # Make the TF agent
    if agent_name == 'latent_sac':
      # Get model network for latent sac
      if model_network_ctor_type == 'hierarchical':
        model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
      elif model_network_ctor_type == 'non-hierarchical':
        model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
      else:
        raise NotImplementedError
      model_net = model_network_ctor(input_names, input_names+mask_names)

      # Get the latent spec
      latent_size = model_net.latent_size
      latent_observation_spec = tensor_spec.TensorSpec((latent_size,), dtype=tf.float32)
      latent_time_step_spec = ts.time_step_spec(observation_spec=latent_observation_spec)

      # Get actor and critic net
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          latent_observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=normal_projection_net)
      critic_net = critic_network.CriticNetwork(
          (latent_observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers)

      # Build the inner SAC agent based on latent space
      inner_agent = sac_agent.SacAgent(
          latent_time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)
      inner_agent.initialize()

      # Build the latent sac agent
      tf_agent = latent_sac_agent.LatentSACAgent(
          time_step_spec,
          action_spec,
          inner_agent=inner_agent,
          model_network=model_net,
          model_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=model_learning_rate),
          model_batch_size=model_batch_size,
          num_images_per_summary=num_images_per_summary,
          sequence_length=sequence_length,
          gradient_clipping=gradient_clipping,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          fps=fps)

    else:
      # Set up preprocessing layers for dictionary observation inputs
      preprocessing_layers = collections.OrderedDict()
      for name in input_names:
        preprocessing_layers[name] = Preprocessing_Layer(32, 256)
      if len(input_names) < 2:
        preprocessing_combiner = None

      if agent_name == 'dqn':
        q_rnn_net = q_rnn_network.QRnnNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = dqn_agent.DqnAgent(
            time_step_spec,
            action_spec,
            q_network=q_rnn_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=1,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=q_learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

      elif agent_name == 'ddpg' or agent_name == 'td3':
        actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
          observation_spec,
          action_spec,
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          input_fc_layer_params=actor_fc_layers,
          lstm_size=actor_lstm_size,
          output_fc_layer_params=actor_output_fc_layers)

        critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
          (observation_spec, action_spec),
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          lstm_size=critic_lstm_size,
          output_fc_layer_params=critic_output_fc_layers)

        if agent_name == 'ddpg':
          tf_agent = ddpg_agent.DdpgAgent(
              time_step_spec,
              action_spec,
              actor_network=actor_rnn_net,
              critic_network=critic_rnn_net,
              actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=actor_learning_rate),
              critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=critic_learning_rate),
              ou_stddev=ou_stddev,
              ou_damping=ou_damping,
              target_update_tau=target_update_tau,
              target_update_period=target_update_period,
              dqda_clipping=dqda_clipping,
              td_errors_loss_fn=None,
              gamma=gamma,
              reward_scale_factor=reward_scale_factor,
              gradient_clipping=gradient_clipping,
              debug_summaries=debug_summaries,
              summarize_grads_and_vars=summarize_grads_and_vars)
        elif agent_name == 'td3':
          tf_agent = td3_agent.Td3Agent(
              time_step_spec,
              action_spec,
              actor_network=actor_rnn_net,
              critic_network=critic_rnn_net,
              actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=actor_learning_rate),
              critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=critic_learning_rate),
              exploration_noise_std=exploration_noise_std,
              target_update_tau=target_update_tau,
              target_update_period=target_update_period,
              actor_update_period=actor_update_period,
              dqda_clipping=dqda_clipping,
              td_errors_loss_fn=None,
              gamma=gamma,
              reward_scale_factor=reward_scale_factor,
              gradient_clipping=gradient_clipping,
              debug_summaries=debug_summaries,
              summarize_grads_and_vars=summarize_grads_and_vars,
              train_step_counter=global_step)

      elif agent_name == 'sac':
        actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=normal_projection_net)

        critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
          (observation_spec, action_spec),
          preprocessing_layers=preprocessing_layers,
          preprocessing_combiner=preprocessing_combiner,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          lstm_size=critic_lstm_size,
          output_fc_layer_params=critic_output_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_distribution_rnn_net,
            critic_network=critic_rnn_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.math.squared_difference,  # make critic loss dimension compatible
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

      else:
        raise NotImplementedError

    tf_agent.initialize()

    # Get replay buffer
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,  # No parallel environments
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    # Train metrics
    env_steps = tf_metrics.EnvironmentSteps()
    average_return = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]

    # Get policies
    # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    eval_policy = tf_agent.policy
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        time_step_spec, action_spec)
    collect_policy = tf_agent.collect_policy

    # Checkpointers
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=2)
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step,
        max_to_keep=2)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Collect driver
    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)
    

    # Optimize the performance by using tf functions
    initial_collect_driver.run = common.function(initial_collect_driver.run)
    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    if agent_name == 'latent_sac':
      compute_summaries(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        train_step=global_step,
        summary_writer=summary_writer,
        num_episodes=1,
        num_episodes_to_render=1,
        model_net=model_net,
        fps=10,
        image_keys=input_names+mask_names)
    else:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=1,
          train_step=env_steps.result(),
          summary_writer=summary_writer,
          summary_prefix='Eval',
      )
      metric_utils.log_metrics(eval_metrics)

    # Dataset generates trajectories with shape [B, sequence_length + 1, ...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    # Get train step
    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)
    train_step = common.function(train_step)

    if agent_name == 'latent_sac':
      def train_model_step():
        experience, _ = next(iterator)
        return tf_agent.train_model(experience)
      train_model_step = common.function(train_model_step)

    # Training initializations
    time_step = None
    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    # Start training
    for iteration in range(num_iterations):
      start_time = time.time()

      if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
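        # During the first `initial_model_train_steps` iterations, only the sequential
        # latent model is trained (on the initial random-collect replay data); policy
        # collection and agent training start once this pre-training phase is done.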
        train_model_step()
      else:
        # Run collect
        time_step, _ = collect_driver.run(time_step=time_step)

        # Train an iteration
        for _ in range(train_steps_per_iteration):
          train_step()

      time_acc += time.time() - start_time

      # Log training information
      if global_step.numpy() % log_interval == 0:
        logging.info('env steps = %d, average return = %f', env_steps.result(),
                     average_return.result())
        env_steps_per_sec = (env_steps.result().numpy() -
                             env_steps_before) / time_acc
        logging.info('%.3f env steps/sec', env_steps_per_sec)
        tf.summary.scalar(
            name='env_steps_per_sec',
            data=env_steps_per_sec,
            step=env_steps.result())
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

      # Get training metrics
      for train_metric in train_metrics:
        train_metric.tf_summaries(train_step=env_steps.result())

      # Evaluation
      if global_step.numpy() % eval_interval == 0:
        # Log evaluation metrics
        if agent_name == 'latent_sac':
          compute_summaries(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            train_step=global_step,
            summary_writer=summary_writer,
            num_episodes=num_eval_episodes,
            num_episodes_to_render=num_images_per_summary,
            model_net=model_net,
            fps=10,
            image_keys=input_names+mask_names)
        else:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=env_steps.result(),
              summary_writer=summary_writer,
              summary_prefix='Eval',
          )
          metric_utils.log_metrics(eval_metrics)

      # Save checkpoints
      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
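The SAC and latent-SAC branches above pass `continuous_projection_net=normal_projection_net`, and the non-latent branches build a custom `Preprocessing_Layer`; neither is defined in this snippet. A hedged sketch of the projection-net helper, mirroring the usual TF-Agents SAC examples (an assumption about what this code expects, not the repository's actual definition):

# Assumed helper: squashed Gaussian projection for continuous SAC actions.
from tf_agents.agents.sac import sac_agent
from tf_agents.networks import normal_projection_network

def normal_projection_net(action_spec, init_means_output_factor=0.1):
  return normal_projection_network.NormalProjectionNetwork(
      action_spec,
      mean_transform=None,
      state_dependent_std=True,
      init_means_output_factor=init_means_output_factor,
      std_transform=sac_agent.std_clip_transform,
      scale_distribution=True)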
Exemple #17
0
def train_eval(
        root_dir,
        env_name=ENV_NAME,
        num_iterations=ITERATIONS,
        fc_layer_params=LAYER_PARAMETERS,
        # Parameters for collect
        initial_collect_steps=COLLECT_STEPS,
        collect_steps_per_iteration=COLLECT_STEPS,
        epsilon_greedy=GREEDY,
        replay_buffer_capacity=BUFFER,
        # Parameters for target update
        target_update_tau=TAU,
        target_update_period=UPDATE_PERIOD,
        # Parameters for train
        train_steps_per_iteration=TRAIN_ITERATIONS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARN_RATE,
        gamma=GAMMA,
        reward_scale_factor=SCALE,
        gradient_clipping=GRADIENT,
        # Parameters for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Parameters for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        train_op = common.function(tf_agent.train)(experience=experience)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            global_step_call = sess.make_callable(global_step)
            train_step_call = sess.make_callable([train_op, summary_ops])

            timed_at_step = global_step_call()
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time
                start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                train_time += time.time() - start_time

                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info('%.3f steps/sec', steps_per_sec)
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
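Exemple #17 above pulls its hyperparameters from module-level constants (ENV_NAME, ITERATIONS, and so on) defined elsewhere in its source file. Plausible placeholder values, mirroring the other CartPole DQN examples in this collection rather than the original module:

# Assumed constants for Exemple #17 (placeholders, not the original definitions).
ENV_NAME = 'CartPole-v0'
ITERATIONS = 100000
LAYER_PARAMETERS = (100,)
COLLECT_STEPS = 1000
GREEDY = 0.1
BUFFER = 100000
TAU = 0.05
UPDATE_PERIOD = 5
TRAIN_ITERATIONS = 1
BATCH_SIZE = 64
LEARN_RATE = 1e-3
GAMMA = 0.99
SCALE = 1.0
GRADIENT = None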
Exemple #18
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    eval_env_name=None,
    env_load_fn=suite_mujoco.load,
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):

  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]
  eval_summary_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create the environment.
    tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_env_name = eval_env_name or env_name
    eval_py_env = env_load_fn(eval_env_name)

    # Get the data specs from the environment
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    eval_py_policy = py_tf_policy.PyTFPolicy(
        greedy_policy.GreedyPolicy(tf_agent.policy))

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Prepare replay buffer as dataset with invalid transitions filtered.
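    # A two-step sample whose first frame is an episode boundary stitches the end of one
    # episode to the start of the next, so it is not a valid transition for TD learning
    # and is dropped before batching.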
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]
    dataset = replay_buffer.as_dataset(
        sample_batch_size=5 * batch_size,
        num_steps=2).apply(tf.data.experimental.unbatch()).filter(
            _filter_invalid_transition).batch(batch_size).prefetch(
                batch_size * 5)
    dataset_iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = dataset_iterator.get_next()
    train_op = tf_agent.train(trajectories)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    with tf.compat.v1.Session() as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)

      # Initialize training.
      sess.run(dataset_iterator.initializer)
      common.initialize_uninitialized_variables(sess)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      global_step_val = sess.run(global_step)

      if global_step_val == 0:
        # Initial eval of randomly initialized policy
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            eval_py_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            callback=eval_metrics_callback,
            log=True,
        )
        sess.run(eval_summary_flush_op)

        # Run initial collect.
        logging.info('Global step %d: Running initial collect op.',
                     global_step_val)
        sess.run(initial_collect_op)

        # Checkpoint the initial replay buffer contents.
        rb_checkpointer.save(global_step=global_step_val)

        logging.info('Finished initial collect.')
      else:
        logging.info('Global step %d: Skipping initial collect op.',
                     global_step_val)

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          total_loss, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_flush_op)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)
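The boundary-filtering pipeline above is written for graph mode, using `tf.data.experimental.unbatch()` and an initializable iterator. A hedged eager-mode sketch of the same idea with the non-experimental `Dataset.unbatch()`; the replay buffer, batch size, and agent are assumed to be built as in the surrounding examples:

def make_filtered_iterator(replay_buffer, batch_size):
  """Eager-mode variant of the boundary-filtered dataset above (a sketch)."""
  def _filter_invalid_transition(trajectories, unused_info):
    # Drop two-step samples whose first step ends an episode.
    return ~trajectories.is_boundary()[0]

  dataset = (replay_buffer.as_dataset(sample_batch_size=5 * batch_size,
                                      num_steps=2)
             .unbatch()
             .filter(_filter_invalid_transition)
             .batch(batch_size)
             .prefetch(5 * batch_size))
  return iter(dataset)

# Usage: experience, _ = next(make_filtered_iterator(replay_buffer, batch_size))
#        loss_info = tf_agent.train(experience)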
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        action_spec = tf_env.action_spec()
        num_actions = action_spec.maximum - action_spec.minimum + 1

        if train_sequence_length > 1:
            q_net = create_recurrent_network(input_fc_layer_params, lstm_size,
                                             output_fc_layer_params,
                                             num_actions)
        else:
            q_net = create_feedforward_network(fc_layer_params, num_actions)
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
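The example above calls `create_feedforward_network` and `create_recurrent_network`, which are not shown here. A hedged sketch of what the feedforward helper typically looks like, built from Keras layers wrapped in a TF-Agents `sequential.Sequential` network (an assumption about the original helpers; the recurrent variant would swap an LSTM trunk in between):

import tensorflow as tf
from tf_agents.networks import sequential

def create_feedforward_network(fc_layer_units, num_actions):
  # Dense trunk followed by a linear head producing one Q-value per action.
  dense = [
      tf.keras.layers.Dense(units, activation=tf.keras.activations.relu)
      for units in fc_layer_units
  ]
  q_values_layer = tf.keras.layers.Dense(
      num_actions,
      activation=None,
      kernel_initializer=tf.keras.initializers.RandomUniform(-0.03, 0.03),
      bias_initializer=tf.keras.initializers.Constant(-0.2))
  return sequential.Sequential(dense + [q_values_layer])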
Exemple #20
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy()
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, _ = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories,
                                  train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time
                start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                train_time += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    tf.logging.info('%.3f steps/sec' % steps_per_sec)
                    tf.logging.info(
                        'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(darc_envs.get_broken_joint_env,
                                       env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(darc_envs.get_falling_env,
                                       env_name=base_name)
    else:
        raise NotImplementedError("Unknown environment: %s" % environment_name)

    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

    eval_metrics_list = []
    for name in eval_name_list:
        eval_metrics_list.append([
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           name="AverageReturn_%s" % name),
        ])

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )

        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            assert checkpoint_path is not None
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Build the initial-collect drivers unconditionally; they are only run
        # further below when the replay buffer is still empty.
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=initial_collect_steps,
        )
        real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            initial_collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=real_initial_collect_steps,
        )

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            real_initial_collect_driver.run = common.function(
                real_initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if replay_buffer.num_frames() == 0:
            logging.info(
                "Initializing replay buffer by collecting experience for %d steps with "
                "a random policy.",
                initial_collect_steps,
            )
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix="Metrics-%s" % eval_name,
            )
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience, real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().
            if (global_step.numpy() % real_collect_interval == 0
                    and global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(name="global_steps_per_sec",
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                for eval_name, eval_env, eval_metrics in zip(
                        eval_name_list, eval_env_list, eval_metrics_list):
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix="Metrics-%s" % eval_name,
                    )
                    metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
Exemple #22
0
env.close()

################################################################################
"""
Another way to run the environment is by using TF-Agents.
"""

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_tf_policy

py_env = suite_gym.load("gym_ctf:ctf-v0")

env = tf_py_environment.TFPyEnvironment(py_env)

# This creates a randomly initialized policy that the agent will follow.
# Similar to just taking random actions in the environment.
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                         env.action_spec())

time_step = env.reset()

while not time_step.is_last():
    action_step = policy.action(time_step)
    time_step = env.step(action_step.action)
    py_env.render('human')  # Default for this render is rgb_array.

py_env.close()

################################################################################
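# Several of the examples below call a compute_avg_return() helper without
# showing its definition. A minimal sketch of such a helper, following the
# standard TF-Agents DQN tutorial pattern (the exact helper those examples use
# is an assumption):


def compute_avg_return(environment, policy, num_episodes=10):
    """Roll out `policy` for `num_episodes` episodes and average the returns."""
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]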
Exemple #23
0
def simulate():
    # Set up the environments for the agent to train and test its performance
    envTrain = ComputerSnake.Snake()
    envEval = ComputerSnake.Snake(persistence=True)

    # Convert and wrap in TFPyEnvironment training and evaluation environments
    train_env = tf_py_environment.TFPyEnvironment(envTrain)
    eval_env = tf_py_environment.TFPyEnvironment(envEval)

    # Set up q network with necessary parameters
    fc_layer_params = (100, )
    q_net = q_network.QNetwork(train_env.observation_spec(),
                               train_env.action_spec(),
                               fc_layer_params=fc_layer_params)
    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate)  # look up
    train_step_counter = tf.Variable(0)

    # Set up and initialize the DQN learning agent. It takes in the time_step spec,
    # action spec, the q network, the optimizer, a loss function, and train_step_counter
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,  # look up
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter)
    agent.initialize()

    # Set up policies the agent can use
    eval_policy = agent.policy
    collect_policy = agent.collect_policy

    # Policy which randomly selects actions for each step
    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                    train_env.action_spec())

    #Buffer to store previous states
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length)

    # The dataset generates trajectories with shape [Bx2x...] so that the agent has access to both the current
    # and next step when computing the loss. Parallel calls and prefetching are used to speed up the pipeline.
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # (Optional) Optimize by wrapping some of the code in a graph using TF function.
    agent.train = common.function(agent.train)

    # Reset the train step
    agent.train_step_counter.assign(0)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)

    # Initially fill the replay buffer with experience collected by the random policy so the agent has data to train on
    collect_data(train_env, random_policy, replay_buffer, steps=5000)
    train_env.reset()

    # Here, we run the simulation to train the agent
    scores_list = []
    num_steps_arr = []
    for currStep in range(num_iterations):
        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        # Number of training steps so far
        step = agent.train_step_counter.numpy()

        # Log progress every log_interval training steps
        if step % log_interval == 0:
            print('Moves made = {0}'.format(step))

        # Evaluates the agent's policy every eval_interval steps, prints the results,
        # and saves them so they can be plotted later
        if step % eval_interval == 0:
            avg_return = 0
            for i in range(num_eval_episodes):
                curr_return = compute_avg_return(eval_env, agent.policy, 1)
                scores_list.append(curr_return)
                num_steps_arr.append(currStep)
                avg_return += curr_return
            avg_return = avg_return / num_eval_episodes
            print('step = {0}: Average Return = {1}'.format(step, avg_return))
    plt.scatter(num_steps_arr, scores_list)
    plt.xlabel('Number of Steps Trained')
    plt.ylabel('Score')
    plt.title('Snake Reinforcement Learning')
    plt.show()
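# Exemple #23 above (and Exemple #25 below) also rely on collect_step() and
# collect_data() helpers that are not shown. A minimal sketch, mirroring the
# pattern used by the multi-drone example at the end of this page (the exact
# helpers those examples use are an assumption):
from tf_agents.trajectories import trajectory


def collect_step(environment, policy, buffer):
    # Take one action with `policy`, observe the transition and store it.
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    buffer.add_batch(traj)


def collect_data(env, policy, buffer, steps):
    # Repeat collect_step() for a fixed number of environment steps.
    for _ in range(steps):
        collect_step(env, policy, buffer)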
Exemple #24
0
def train_eval(
        root_dir,
        env_name='gym_orbital_system:solarsystem-v0',
        eval_env_name=None,
        env_load_fn=suite_gym.load,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=50000,
        policy_checkpoint_interval=50000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'td3_eval/train')
    eval_dir = os.path.join(root_dir, 'td3_eval/eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
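# The SAC train_eval() above accepts an eval_metrics_callback and invokes it as
# eval_metrics_callback(results, global_step). A minimal sketch of such a
# callback (assuming `results` is a metric-name -> value mapping, which is what
# metric_utils.eager_compute returns; the callback itself is hypothetical):
from absl import logging


def log_eval_metrics_callback(results, step):
    # Log every evaluation metric alongside the global step it was computed at.
    for name, value in results.items():
        logging.info('eval %s at step %d: %s', name, step, value)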
Exemple #25
0
def main():
    # Create train and evaluation environments for Tensorflow
    train_py_env = Environment.Environment()
    train_env = tf_py_environment.TFPyEnvironment(train_py_env)

    eval_py_env = Environment.Environment()
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)


    # utils.validate_py_environment(train_py_env, episodes=5)

    # Set up an agent
    # Decide on layers of a network
    fc_layer_params = (50, 200, 25, 6)
    #conv_layer_params = [(4, 4, 1), (8, 4, 2)]
    # QNetwork predicts QValues (expected returns) for all actions based on observation on the given environment
    q_net = q_network.QNetwork(train_env.observation_spec(),
                               train_env.action_spec(),
                               #conv_layer_params=conv_layer_params,
                               fc_layer_params=fc_layer_params)
    # Initialize the DQN agent with the train environment's time_step spec, action spec,
    # the QNetwork, the Adam optimizer, a loss function, and a train step counter
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    # Variable maintains shared, persistent state manipulated by a program.
    # 0 is the initial value.
    # After construction, the type and shape of the variable are fixed.
    train_step_counter = tf.Variable(0)
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        epsilon_greedy=0.4,  #TODO tune this
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter,
        #boltzmann_temperature=0.1,
        summarize_grads_and_vars=True)

    agent.initialize()

    # Policies

    """A policy defines the way an agent acts in an environment. 
    Typically, the goal of RL is to train the underlying model until the policy produces the desired outcome.
    
    Agents contain two policies:
    agent.policy — The main policy that is used for evaluation and deployment.
    agent.collect_policy — A second policy that is used for data collection.
    """

    # tf_agents.policies.random_tf_policy creates a policy which will randomly select an action for each time_step (independent of agent)
    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec())

    # Baseline average return of the moves based on random_policy (random actions of an agent)
    print(compute_avg_return(eval_env, random_policy, num_eval_episodes))

    # Replay buffer

    # The replay buffer keeps track of data collected from the environment.
    # This example uses tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer, as it is the most common.
    # The constructor requires the specs for the data it will be collecting.
    # This is available from the agent using the collect_data_spec method.
    # The batch size and maximum buffer length are also required.

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_max_length,
        dataset_window_shift=1)

    # The agent needs access to the replay buffer.
    # This is provided by creating an iterable tf.data.Dataset pipeline which will feed data to the agent.
    # Each row of the replay buffer only stores a single observation step.
    # But since the DQN Agent needs both the current and next observation to compute the loss,
    # the dataset pipeline will sample two adjacent rows for each item in the batch (num_steps=2).
    # This dataset is also optimized by running parallel calls and prefetching data.

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        single_deterministic_pass=False,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # Train the agent

    agent.train = common.function(agent.train)

    # Reset the train step
    agent.train_step_counter.assign(0)

    collect_data(train_env, random_policy, replay_buffer, steps=10000)

    for _ in range(num_iterations):

        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        step = agent.train_step_counter.numpy()
        if step % log_interval == 0:
            avg_return = compute_avg_return(eval_env, agent.policy, 3) #TODO
            if not os.path.exists("eval_data"):
                os.makedirs("eval_data")
            path = os.path.join("eval_data", f'Eval_data.step{step // log_interval}.txt')
            with open(path, 'w') as f:
                for move in eval_py_env.all_moves:
                    print(str(move), file=f)
            eval_py_env.all_moves = []
            print('step = {0}: loss = {1}, Average Return: {2}'.format(step, train_loss, avg_return))
def test(binance, model):

    symbol = "BTCUSDT"

    df = pd.read_csv("..\\Data\\" + symbol + "_data.csv",
                     index_col=0,
                     parse_dates=True)
    df = make_dataset.make_reinforcement_dataset(df)

    train_env_py = TradingEnv(df)
    train_env_py.set_init_balance(1000)
    eval_env_py = TradingEnv(df)
    eval_env_py.set_init_balance(1000)
    #train_env_py = suite_gym.load('CartPole-v0')
    #eval_env_py = suite_gym.load('CartPole-v0')

    train_env = tf_py_environment.TFPyEnvironment(train_env_py)
    eval_env = tf_py_environment.TFPyEnvironment(eval_env_py)

    #q_net = q_rnn_network.QRnnNetwork(
    #    train_env.observation_spec(),
    #    train_env.action_spec(),
    #    input_fc_layer_params=(128,64,16),
    #    output_fc_layer_params=(128,64,16),
    #    lstm_size=(128,64,16))

    q_net = q_network.QNetwork(train_env.observation_spec(),
                               train_env.action_spec(),
                               fc_layer_params=(256, 128, 64, 32, 16))

    global_step = tf.Variable(0, name="global_step", trainable=False)

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001)

    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_step,
        epsilon_greedy=0.2)

    agent.initialize()

    eval_policy = agent.policy
    collect_policy = agent.collect_policy

    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                    train_env.action_spec())

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length)

    collect_data(train_env, random_policy, replay_buffer, steps=10000)

    policy_checkpointer = common.Checkpointer(
        ckpt_dir="ReinforcementLearnData/Checkpoint",
        agent=agent,
        policy=agent.policy,
        replay_buffer=replay_buffer,
        global_step=global_step)
    policy_checkpointer.initialize_or_restore()
    tf_policy_saver = policy_saver.PolicySaver(agent.policy)

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = iter(dataset)

    print("1:Train 2:Evaluate 3:Simulation")
    mode = input(">")
    if mode == "1":
        pass
    elif mode == "2":
        evaluate(agent, eval_env)
        return
    elif mode == "3":
        simulation(binance, agent)
        return

    # (Optional) Optimize by wrapping some of the code in a graph using TF function.
    agent.train = common.function(agent.train)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    returns = [avg_return]

    for _ in range(num_iterations):

        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        step = agent.train_step_counter.numpy()

        if step % log_interval == 0:
            print("step = {0}: loss = {1}".format(step, train_loss))

        if step % eval_interval == 0:
            avg_return = compute_avg_return(eval_env, agent.policy,
                                            num_eval_episodes)
            print("step = {0}: Average Return = {1}".format(step, avg_return))
            returns.append(avg_return)

    #save agent
    policy_checkpointer.save(global_step)
    tf_policy_saver.save("ReinforcementLearnData/Policy")

    x = range(0, num_iterations + 1, eval_interval)
    plt.plot(x, returns)
    plt.ylabel('Average Return')
    plt.xlabel('Iterations')
    plt.show()

    print("Result:")
    print(compute_avg_return(eval_env, agent.policy, num_eval_episodes))
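# The policy saved above with PolicySaver can later be restored as a SavedModel
# and used for inference without rebuilding the agent. A minimal sketch (the
# directory matches the tf_policy_saver.save() call above; the helper itself and
# the eval_env argument are assumptions):
import tensorflow as tf


def run_saved_policy(eval_env, policy_dir="ReinforcementLearnData/Policy"):
    # Restore the SavedModel written by PolicySaver and roll out one episode.
    saved_policy = tf.saved_model.load(policy_dir)
    time_step = eval_env.reset()
    while not time_step.is_last():
        policy_step = saved_policy.action(time_step)
        time_step = eval_env.step(policy_step.action)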
Exemple #27
0
def train_level(level,
                consecutive_wins_flag=5,
                collect_random_steps=True,
                max_iterations=num_iterations):
    """
    create DQN agent to train a level of the game
    :param level: level of the game
    :param consecutive_wins_flag: number of consecutive wins in evaluation
    signifying the training is done
    :param collect_random_steps: whether to collect random steps at the beginning,
    always set to 'True' when the global step is 0.
    :param max_iterations: stop the training when it reaches the max iteration
    regardless of the result
    """
    global saving_time
    cells = query_level(level)
    size = len(cells)
    env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells))
    eval_env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells))

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    fc_layer_params = (neuron_num_mapper[size], )

    q_net = q_network.QNetwork(env.observation_spec()[0],
                               env.action_spec(),
                               fc_layer_params=fc_layer_params,
                               activation_fn=tf.keras.activations.relu)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DdqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_step,
        observation_and_action_constraint_splitter=GameEnv.
        obs_and_mask_splitter)
    agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=env.batch_size,
        max_length=replay_buffer_max_length)

    # drivers
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        env,
        policy=agent.collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=collect_steps_per_iteration)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        eval_env,
        policy=agent.policy,
        observers=eval_metrics,
        num_episodes=num_eval_episodes)

    # checkpointer of the replay buffer and policy
    train_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        dir_path, 'trained_policies/train_lv{0}'.format(level)),
                                             max_to_keep=1,
                                             agent=agent,
                                             policy=agent.policy,
                                             global_step=global_step,
                                             replay_buffer=replay_buffer)

    # policy saver
    tf_policy_saver = policy_saver.PolicySaver(agent.policy)

    train_checkpointer.initialize_or_restore()

    # optimize by wrapping some of the code in a graph using TF function
    agent.train = common.function(agent.train)
    collect_driver.run = common.function(collect_driver.run)
    eval_driver.run = common.function(eval_driver.run)

    # collect initial replay data
    if collect_random_steps:
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec=env.time_step_spec(),
            action_spec=env.action_spec(),
            observation_and_action_constraint_splitter=GameEnv.
            obs_and_mask_splitter)

        dynamic_step_driver.DynamicStepDriver(
            env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # Train until consecutive_wins_flag consecutive evaluations have an average return greater than 10,
    # or until max_iterations is reached
    consecutive_eval_win = 0
    train_iterations = 0
    while consecutive_eval_win < consecutive_wins_flag and train_iterations < max_iterations:
        collect_driver.run()

        for _ in range(collect_steps_per_iteration):
            experience, _ = next(iterator)
            train_loss = agent.train(experience).loss

        # evaluate the training at intervals
        step = global_step.numpy()
        if step % eval_interval == 0:
            eval_driver.run()
            average_return = eval_metrics[0].result().numpy()
            average_len = eval_metrics[1].result().numpy()
            print("level: {0} step: {1} AverageReturn: {2} AverageLen: {3}".
                  format(level, step, average_return, average_len))

            # evaluate consecutive wins
            if average_return > 10:
                consecutive_eval_win += 1
            else:
                consecutive_eval_win = 0

        if step % save_interval == 0:
            start = time.time()
            train_checkpointer.save(global_step=step)
            saving_time += time.time() - start

        train_iterations += 1

    # save the policy
    train_checkpointer.save(global_step=global_step.numpy())
    tf_policy_saver.save(
        os.path.join(dir_path, 'trained_policies/policy_lv{0}'.format(level)))
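# Exemple #27 above passes GameEnv.obs_and_mask_splitter as the agent's
# observation_and_action_constraint_splitter. Such a splitter takes the raw
# observation and returns (network_observation, action_mask). A minimal sketch,
# assuming GameEnv packs the board and a boolean legal-action mask into a tuple
# (the real GameEnv implementation is not shown):


def obs_and_mask_splitter(observation):
    # The Q-network only sees the board; the mask disables illegal actions.
    board, legal_action_mask = observation
    return board, legal_action_mask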
def main(
    grid_size=options.grid_size,
    sensing_locations_amount=options.sensing_locations_amount,
    drones_amount=options.drones_amount,
    drone_max_speed=options.drone_max_speed,
    drone_bandwidth=options.drone_bandwidth,
    total_location_data=options.total_location_data,
    cycles_num=options.cycles_num,
    random=options.random,
):
    sensing_locations = (np.random.rand(sensing_locations_amount, 2) * grid_size) - (
        grid_size / 2
    )
    options.sensing_locations = sensing_locations
    options.sensing_locations_amount = sensing_locations_amount
    options.grid_size = grid_size
    options.drones_amount = drones_amount
    options.drone_max_speed = drone_max_speed
    options.drone_bandwidth = drone_bandwidth
    options.cycles_num = cycles_num
    options.random = random

    # -- BEGIN of the jupyter notebook's code (please refer to it for a more readable version)
    # Link to the online notebook: https://github.com/AlessioLuciani/distributed-uav-rl-protocol/blob/main/simulation.ipynb

    options.drones_locations = np.zeros(
        (options.drones_amount, 2), dtype=float)
    options.sensing_data_amounts = np.zeros(options.drones_amount, dtype=float)
    options.aois = np.zeros(options.sensing_locations_amount, dtype=int)
    options.chosen_locations = np.zeros(options.drones_amount, dtype=int)
    options.cycle_stages = np.zeros(options.drones_amount, dtype=int)
    options.data_transmission_cycle = options.drone_bandwidth * options.cycle_length

    options.aois_vec = np.zeros(
        (cycles_num, options.sensing_locations_amount), dtype=int
    )
    options.drones_vec = np.zeros(
        (cycles_num, options.drones_amount, 2), dtype=float)
    options.chosen_loc_vec = np.zeros(
        ((cycles_num, options.drones_amount)), dtype=int)
    options.cycle_stages_vec = np.zeros(
        (cycles_num, options.drones_amount), dtype=int)

    class DurpEnv(py_environment.PyEnvironment):
        def __init__(self, drone):
            self._drone = drone
            self._action_spec = array_spec.BoundedArraySpec(
                shape=(),
                dtype=np.int32,
                minimum=0,
                maximum=options.sensing_locations_amount - 1,
                name="action",
            )
            self._observation_spec = array_spec.BoundedArraySpec(
                shape=(2,),
                minimum=(-options.grid_size / 2),
                maximum=(options.grid_size / 2),
                dtype=np.float64,
            )

        def action_spec(self):
            return self._action_spec

        def observation_spec(self):
            return self._observation_spec

        def _reset(self):
            return ts.restart(np.array([0.0, 0.0], dtype=np.float64))

        def _step(self, action):
            chosen_location_index = int(action)
            accumulated_aoi = get_accumulated_aoi(options.current_cycle[0])
            options.aois[chosen_location_index] = options.current_cycle[0]
            new_accumulated_aoi = get_accumulated_aoi(options.current_cycle[0])
            aoi_multiplier = 0.05
            normalized_diff_aoi_component = (
                (
                    1
                    / (
                        1
                        + np.exp(
                            -(accumulated_aoi - new_accumulated_aoi) *
                            aoi_multiplier
                        )
                    )
                )
                - 0.5
            ) * 2.0
            chosen_location = options.sensing_locations[chosen_location_index]
            drone_location = options.drones_locations[self._drone]
            distance = np.linalg.norm(chosen_location - drone_location)
            distance_multiplier = 0.03
            normalized_location_distance = (
                1 - (1 / (1 + np.exp(-distance * distance_multiplier)))
            ) * 2.0
            aoi_weight = 0.7
            distance_weight = 0.3
            reward = (
                normalized_diff_aoi_component * aoi_weight
                + normalized_location_distance * distance_weight
                + np.random.random() * (1.0 - aoi_weight - distance_weight)
            )

            return ts.transition(chosen_location, reward=reward)

    learning_rate = 0.001
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    environments = []
    agents = []
    for drone in range(options.drones_amount):
        durp_env = DurpEnv(drone)
        train_env = tf_py_environment.TFPyEnvironment(durp_env)
        q_net = q_network.QNetwork(
            train_env.observation_spec(), train_env.action_spec()
        )

        train_step_counter = tf.Variable(0)

        agent = dqn_agent.DqnAgent(
            train_env.time_step_spec(),
            train_env.action_spec(),
            q_network=q_net,
            optimizer=optimizer,
            train_step_counter=train_step_counter,
        )

        agent.initialize()
        agents.append(agent)
        environments.append(train_env)

    num_iterations = 5
    intermediate_iterations = 5
    eval_interval = 10
    initial_collect_steps = 1
    collect_steps_per_iteration = 1
    batch_size = 64

    def compute_avg_return(environment, policy, num_episodes=10):
        total_return = 0.0
        for _ in range(num_episodes):
            time_step = environment.reset()
            episode_return = 0.0

            for m in range(10):
                action_step = policy.action(time_step)
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return

        avg_return = total_return / num_episodes
        return avg_return.numpy()[0]

    def collect_step(environment, policy, buffer, drone):
        time_step = environment.current_time_step()
        action_step = policy.action(time_step)
        next_time_step = environment.step(action_step.action)
        options.drones_locations[drone] = next_time_step.observation
        traj = trajectory.from_transition(
            time_step, action_step, next_time_step)

        buffer.add_batch(traj)

    def collect_data(env, policy, buffer, steps, drone):
        for step in range(1, steps + 1):
            options.current_cycle[0] = step
            collect_step(env, policy, buffer, drone)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agents[0].collect_data_spec, batch_size=environments[0].batch_size
    )

    random_policy = random_tf_policy.RandomTFPolicy(
        environments[0].time_step_spec(), environments[0].action_spec()
    )

    reset_aois()
    reset_drones_locations()

    collect_data(
        environments[0], random_policy, replay_buffer, initial_collect_steps, 0
    )

    reset_aois()
    reset_drones_locations()

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2
    ).prefetch(3)
    iterator = iter(dataset)

    # Evaluate each agent's policy once before training to record a baseline return
    returns = np.zeros(
        (options.drones_amount, (num_iterations // eval_interval) + 1), dtype=np.float64
    )
    for k in range(len(agents)):
        avg_return = compute_avg_return(environments[k], agents[k].policy)
        returns[k][0] = avg_return

    for i in range(num_iterations):
        reset_aois()
        reset_drones_locations()
        if i % 100 == 0:
            print("----------------", i)

        for j in range(intermediate_iterations):
            for k in range(len(agents)):
                agent = agents[k]
                env = environments[k]

                # Collect a few steps using collect_policy and save to the replay buffer.
                collect_data(
                    env,
                    agent.collect_policy,
                    replay_buffer,
                    collect_steps_per_iteration,
                    k,
                )

        for k in range(len(agents)):
            agent = agents[k]

            # Sample a batch of data from the buffer and update the agent's network.
            experience, unused_info = next(iterator)
            train_loss = agent.train(experience).loss
            print("Loss:", train_loss)

        if (i + 1) % eval_interval == 0:
            for k in range(len(agents)):
                agent = agents[k]
                avg_return = compute_avg_return(environments[k], agent.policy)
                # Store this evaluation in its own slot; slot 0 already holds
                # the pre-training baseline.
                returns[k][(i + 1) // eval_interval] = avg_return

    time_steps = []
    for drone in range(options.drones_amount):
        time_steps.append(environments[drone].reset())

    reset_aois()
    reset_drones_locations()

    # -- END of the jupyter notebook's code
    # -- Start of the simulation

    for cycle in range(1, options.cycles_num + 1):
        if cycle % 10 == 1:
            print(get_accumulated_aoi(cycle))
            print(options.aois)
            print("----------------")
        for drone in range(options.drones_amount):
            if options.cycle_stages[drone] == 0:
                agent = agents[drone]
                env = environments[drone]
                chosen_location_index = -1
                # `random` is assumed to be a boolean flag set earlier in this
                # script (random vs. learned location selection); if it refers
                # to the `random` module, this condition is always true.
                if random:
                    chosen_location_index = randrange(
                        options.sensing_locations_amount)
                    options.aois[chosen_location_index] = cycle
                else:
                    options.current_cycle[0] = cycle
                    policy_step = agent.policy.action(time_steps[drone]).replace(
                        action=tf.constant(
                            [get_best_action(drone)], dtype=np.int32)
                    )
                    new_step = env.step(policy_step.action)
                    time_steps[drone] = new_step
                    chosen_location_index = int(policy_step.action)
                options.chosen_locations[drone] = chosen_location_index
                options.cycle_stages[drone] = 1
            elif (
                options.drones_locations[drone]
                != options.sensing_locations[options.chosen_locations[drone]]
            ).all():
                traj = get_trajectory(drone)
                new_location = options.drones_locations[drone] + traj
                options.drones_locations[drone] = new_location
            elif options.sensing_data_amounts[drone] == 0.0:
                # The drone has reached its chosen location: load the location's
                # data and move straight to stage 3 (stage 2 is overwritten
                # within the same cycle).
                options.cycle_stages[drone] = 2
                options.sensing_data_amounts[drone] = options.total_location_data
                options.cycle_stages[drone] = 3
            else:
                options.sensing_data_amounts[drone] = np.max(
                    [
                        options.sensing_data_amounts[drone]
                        - options.data_transmission_cycle,
                        0.0,
                    ]
                )
                if options.sensing_data_amounts[drone] == 0.0:
                    options.cycle_stages[drone] = 0
        options.cycle_stages_vec[cycle - 1] = options.cycle_stages
        options.aois_vec[cycle - 1] = [
            get_location_aoi(cycle, index)
            for index in range(options.sensing_locations_amount)
        ]
        options.drones_vec[cycle - 1] = options.drones_locations
        options.chosen_loc_vec[cycle - 1] = options.chosen_locations

    app = QApplication(sys.argv)
    screen_w, screen_h = get_screen_resolution(app)
    simulator = Simulator(screen_w, screen_h)
    simulator.show()
    sys.exit(app.exec())
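
As a side note on the example above: the sigmoid-based distance term in DurpEnv's reward can be inspected in isolation. The sketch below is not part of the original example (the helper name is made up); it simply reproduces the same formula and prints a few sample values.

import numpy as np

def normalized_distance(distance, multiplier=0.03):
    # Same formula as DurpEnv's reward: 2 * (1 - sigmoid(distance * multiplier)).
    # Equals 1.0 at distance 0 and decays toward 0 as the distance grows.
    return (1.0 - 1.0 / (1.0 + np.exp(-distance * multiplier))) * 2.0

for d in (0.0, 10.0, 50.0, 200.0):
    print(d, round(float(normalized_distance(d)), 3))
# 0.0 -> 1.0, 10.0 -> ~0.851, 50.0 -> ~0.365, 200.0 -> ~0.005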
Exemple #29
0
def train_eval(
        root_dir,
        env_name='gym_solventx-v0',
        eval_env_name='gym_solventx-v0',
        env_load_fn=suite_gym.load,
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=_DEFAULT_REWARD_SCALE,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC on Mujoco.

  All hyperparameters come from the original SAC paper
  (https://arxiv.org/pdf/1801.01290.pdf).
  """

    if reward_scale_factor == _DEFAULT_REWARD_SCALE:
        # Use value recommended by https://arxiv.org/abs/1801.01290
        if env_name.startswith('Humanoid'):
            reward_scale_factor = 20.0
        else:
            reward_scale_factor = 5.0

    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
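    # Only record summaries every `summary_interval` global steps.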
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # create training environment
        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        # create evaluation environment
        eval_env_name = eval_env_name or env_name
        eval_py_env = env_load_fn(eval_env_name)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, _RUN_DIR, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, _RUN_DIR, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, _RUN_DIR, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                            data=env_steps_per_sec,
                                            step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=env_steps.result(),
                    summary_writer=summary_writer,
                    summary_prefix='Eval',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, env_steps.result())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
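
A minimal usage sketch for the train_eval function above, assuming the code is importable as a module and that gym_solventx is installed and registered with Gym; the module name, directory, and step counts below are placeholders, not values from the original example.

# Hypothetical module name; point it at wherever train_eval above is defined.
from solventx_sac_train import train_eval

train_eval(
    root_dir='~/runs/solventx_sac',  # checkpoints, policies and summaries go here
    num_iterations=100000,           # shorter than the 1,000,000 default
    initial_collect_steps=1000,
    eval_interval=5000,
    log_interval=500,
)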
Exemple #30
0
    def train_implementation(self, train_context: core.TrainContext):
        """Tf-Agents Ppo Implementation of the train loop.

        The implementation follows
        https://colab.research.google.com/github/tensorflow/agents/blob/master/tf_agents/colabs/1_dqn_tutorial.ipynb
        """
        assert isinstance(train_context, core.StepsTrainContext)
        dc: core.StepsTrainContext = train_context

        train_env = self._create_env(discount=dc.reward_discount_gamma)
        observation_spec = train_env.observation_spec()
        action_spec = train_env.action_spec()
        timestep_spec = train_env.time_step_spec()

        # SetUp Optimizer, Networks and DqnAgent
        self.log_api('AdamOptimizer', '()')
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=dc.learning_rate)
        self.log_api('QNetwork', '()')
        q_net = q_network.QNetwork(observation_spec,
                                   action_spec,
                                   fc_layer_params=self.model_config.fc_layers)
        self.log_api('DqnAgent', '()')
        tf_agent = dqn_agent.DqnAgent(
            timestep_spec,
            action_spec,
            q_network=q_net,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss)

        self.log_api('tf_agent.initialize', '()')
        tf_agent.initialize()
        self._trained_policy = tf_agent.policy

        # SetUp Data collection & Buffering
        self.log_api('TFUniformReplayBuffer', '()')
        replay_buffer = TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=train_env.batch_size,
            max_length=dc.max_steps_in_buffer)
        self.log_api('RandomTFPolicy', '()')
        random_policy = random_tf_policy.RandomTFPolicy(
            timestep_spec, action_spec)
        self.log_api('replay_buffer.add_batch', '(trajectory)')
        for _ in range(dc.num_steps_buffer_preload):
            self.collect_step(env=train_env,
                              policy=random_policy,
                              replay_buffer=replay_buffer)

        # Train
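        # Wrap tf_agent.train in a tf.function (graph-compiled) for speed;
        # autograph is disabled.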
        tf_agent.train = common.function(tf_agent.train, autograph=False)
        self.log_api(
            'replay_buffer.as_dataset', '(num_parallel_calls=3, '
            f'sample_batch_size={dc.num_steps_sampled_from_buffer}, num_steps=2).prefetch(3)'
        )
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=dc.num_steps_sampled_from_buffer,
            num_steps=2).prefetch(3)
        iter_dataset = iter(dataset)
        self.log_api('for each iteration')
        self.log_api('  replay_buffer.add_batch', '(trajectory)')
        self.log_api('  tf_agent.train', '(experience=trajectory)')
        while True:
            self.on_train_iteration_begin()
            for _ in range(dc.num_steps_per_iteration):
                self.collect_step(env=train_env,
                                  policy=tf_agent.collect_policy,
                                  replay_buffer=replay_buffer)
            trajectories, _ = next(iter_dataset)
            tf_loss_info = tf_agent.train(experience=trajectories)
            self.on_train_iteration_end(tf_loss_info.loss)
            if train_context.training_done:
                break
        return
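
The loop above relies on self.collect_step(...), which is not shown in this excerpt. A minimal sketch of such a helper, following the same TF-Agents pattern as the collect_step in the drone example earlier (the standalone function below is an assumption about the missing method, not its actual implementation):

from tf_agents.trajectories import trajectory

def collect_step(env, policy, replay_buffer):
    # Advance the environment one step under `policy` and store the resulting
    # (time_step, action, next_time_step) transition in the replay buffer.
    time_step = env.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = env.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    replay_buffer.add_batch(traj)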