def _build_metrics(self, buffer_size=10, batch_size=None):
        """Build paired Python-wrapped and native TensorFlow episode metrics.

        Args:
          buffer_size: Window size forwarded to every metric constructor.
          batch_size: Batch size for the metrics. The Python metrics receive
            it unchanged (it may be None); the TensorFlow metrics receive 1
            when it is None.

        Returns:
          A `(python_metrics, tensorflow_metrics)` tuple. Each element is a
          list holding an average-return metric followed by an
          average-episode-length metric.
        """
        py_metric_classes = (py_metrics.AverageReturnMetric,
                             py_metrics.AverageEpisodeLengthMetric)
        python_metrics = [
            tf_py_metric.TFPyMetric(
                metric_cls(buffer_size=buffer_size, batch_size=batch_size))
            for metric_cls in py_metric_classes
        ]

        # Only the TensorFlow metrics get a concrete default batch size.
        tf_batch_size = 1 if batch_size is None else batch_size
        tf_metric_classes = (tf_metrics.AverageReturnMetric,
                             tf_metrics.AverageEpisodeLengthMetric)
        tensorflow_metrics = [
            metric_cls(buffer_size=buffer_size, batch_size=tf_batch_size)
            for metric_cls in tf_metric_classes
        ]

        return python_metrics, tensorflow_metrics
Example #2
0
 def __init__(self,
              environment=None,
              agent=None,
              tracer=None,
              params=None):
   """Wire up the runner's parameters, metrics, agent, environment and drivers.

   Args:
     environment: Environment object; it is wrapped in `TFAWrapper` and then
       in `tf_py_environment.TFPyEnvironment`.
     agent: Agent object. Its `set_action_externally` attribute is set to
       True here, so passing None fails. NOTE(review): despite the None
       default, this argument appears to be required.
     tracer: Tracer instance; a new `Tracer()` is used when this is falsy.
     params: Parameter server; a new `ParameterServer()` is used when this
       is falsy.
   """
   self._params = params if params else ParameterServer()
   # Both evaluation metrics read their buffer size from
   # ML/TFARunner/EvaluationSteps (default 25). The lookup happens once per
   # metric, in the same order as the metric list.
   self._eval_metrics = [
     metric_cls(
       buffer_size=self._params["ML"]["TFARunner"]["EvaluationSteps", "", 25])
     for metric_cls in (tf_metrics.AverageReturnMetric,
                        tf_metrics.AverageEpisodeLengthMetric)]
   self._agent = agent
   self._agent.set_action_externally = True
   self._summary_writer = None
   self._environment = environment
   wrapped_py_env = TFAWrapper(self._environment)
   self._wrapped_env = tf_py_environment.TFPyEnvironment(wrapped_py_env)
   # The driver builders run after the agent/environment attributes above are
   # assigned; keep this order. NOTE(review): presumably they read them.
   self.GetInitialCollectionDriver()
   self.GetCollectionDriver()
   self._logger = logging.getLogger()
   self._tracer = tracer if tracer else Tracer()
Example #3
0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        actor_update_period=2,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        eval_metrics_callback=None):
    """Train and evaluate a TD3 agent with recurrent networks (TF1 graph mode).

    Although adapted from a DDPG example, the agent built here is a
    `td3_agent.Td3Agent` whose actor and critic are RNN networks. Training
    uses a `suite_dm_control` environment. Evaluation uses a second,
    separately loaded copy of the same environment.

    Args:
      root_dir: Base directory. Checkpoints and train summaries go to
        `<root_dir>/train`; eval summaries go to `<root_dir>/eval`.
      env_name: DM Control domain name passed to `suite_dm_control.load`.
      task_name: DM Control task name passed to `suite_dm_control.load`.
      observations_whitelist: If not None, the single observation key kept
        by wrapping the env in `wrappers.FlattenObservationsWrapper`.
      num_iterations: Number of collect-then-train iterations.
      actor_fc_layers, actor_output_fc_layers, actor_lstm_size: Actor RNN
        network layer sizes.
      critic_obs_fc_layers, critic_action_fc_layers, critic_joint_fc_layers,
        critic_output_fc_layers, critic_lstm_size: Critic RNN network layer
        sizes.
      initial_collect_steps: Number of *episodes* collected before training.
        Despite the name, it is passed as `num_episodes` to a
        `DynamicEpisodeDriver`.
      collect_episodes_per_iteration: Episodes collected per iteration.
      replay_buffer_capacity: Max length of the uniform replay buffer.
      exploration_noise_std, target_update_tau, target_update_period,
        actor_update_period, dqda_clipping, gamma, reward_scale_factor,
        debug_summaries: Forwarded to `Td3Agent`.
      train_steps_per_iteration: Agent train calls per iteration.
      batch_size: Sample batch size for the replay-buffer dataset.
      train_sequence_length: Sampled sequences have this length plus one.
      actor_learning_rate, critic_learning_rate: Adam learning rates.
      num_eval_episodes: Episodes per evaluation and eval-metric buffer size.
      eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval, log_interval: Global-step intervals checked
        after each iteration.
      summary_interval: Train summaries are recorded only when the global
        step is a multiple of this.
      summaries_flush_secs: Summary writer flush period, in seconds.
      eval_metrics_callback: Optional callback forwarded to
        `metric_utils.compute_summaries` after each evaluation.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Create the step counter before the record_if predicate that reads it.
    # Otherwise correctness depends on the lambda only being evaluated after
    # a later assignment inside the `with` body.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            debug_summaries=debug_summaries,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        # NOTE: initial_collect_steps is used as an episode count here.
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run(policy_state=policy_state)

        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            # Stays None if train_steps_per_iteration == 0. Previously the
            # log line below then raised NameError.
            loss_info_value = None
            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                time_acc += time.time() - start_time

                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    if loss_info_value is not None:
                        logging.info('step = %d, loss = %f', global_step_val,
                                     loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
Example #4
0
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and evaluate a PPO agent on parallel environments (TF1 graph mode).

    Each iteration collects whole episodes into a replay buffer, trains on
    everything gathered, and then clears the buffer. Training stops once the
    environment-step metric reaches `num_environment_steps`. A final
    evaluation is then run.

    Args:
      root_dir: Base directory. Checkpoints and train summaries go to
        `<root_dir>/train`; eval summaries go to `<root_dir>/eval`.
      tf_master: Target passed to `tf.compat.v1.Session`.
      env_name: Environment name passed to `env_load_fn`.
      env_load_fn: Callable that loads a single Python environment by name.
      random_seed: Graph-level seed set via `tf.compat.v1.set_random_seed`.
      actor_fc_layers, value_fc_layers: Network layer sizes.
      use_rnns: Use RNN actor/value networks instead of feed-forward ones.
      num_environment_steps: Stop once this many environment steps were taken.
      collect_episodes_per_iteration: Episodes collected per iteration.
      num_parallel_environments: Number of environments in each
        `ParallelPyEnvironment`; also the replay-buffer batch size.
      replay_buffer_capacity: Max length per environment of the replay buffer.
      num_epochs: Forwarded to `PPOAgent`.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation and eval-metric buffer size.
      eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval, log_interval: Global-step intervals.
      summary_interval: Train summaries are recorded only when the global
        step is a multiple of this.
      summaries_flush_secs: Summary writer flush period, in seconds.
      debug_summaries, summarize_grads_and_vars: Forwarded to `PPOAgent`.
      eval_metrics_callback: Optional callback forwarded to
        `metric_utils.compute_summaries` after each evaluation.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        # Clear the buffer only after training has consumed it, and make the
        # returned op depend on the clear so one sess.run does both.
        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=step_metrics)

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # Use the same tf.compat.v2 summary API as the writers above
            # instead of the removed-in-TF2 tf.contrib.summary.
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss = sess.run(train_op)
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # Re-read the step: if the loop body never ran (e.g. restored past
            # num_environment_steps), global_step_val would otherwise be
            # unbound here.
            global_step_val = sess.run(global_step)
            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, "train")
    eval_dir = os.path.join(root_dir, "eval")

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    if environment_name == "broken_reacher":
        get_env_fn = darc_envs.get_broken_reacher_env
    elif environment_name == "half_cheetah_obstacle":
        get_env_fn = darc_envs.get_half_cheetah_direction_env
    elif environment_name == "inverted_pendulum":
        get_env_fn = darc_envs.get_inverted_pendulum_env
    elif environment_name.startswith("broken_joint"):
        base_name = environment_name.split("broken_joint_")[1]
        get_env_fn = functools.partial(darc_envs.get_broken_joint_env,
                                       env_name=base_name)
    elif environment_name.startswith("falling"):
        base_name = environment_name.split("falling_")[1]
        get_env_fn = functools.partial(darc_envs.get_falling_env,
                                       env_name=base_name)
    else:
        raise NotImplementedError("Unknown environment: %s" % environment_name)

    eval_name_list = ["sim", "real"]
    eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

    eval_metrics_list = []
    for name in eval_name_list:
        eval_metrics_list.append([
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           name="AverageReturn_%s" % name),
        ])

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env_real = get_env_fn("real")
        if train_on_real:
            tf_env = get_env_fn("real")
        else:
            tf_env = get_env_fn("sim")

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork),
        )
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer="glorot_uniform",
            last_kernel_initializer="glorot_uniform",
        )

        classifier = classifiers.build_classifier(observation_spec,
                                                  action_spec)

        tf_agent = darc_agent.DarcAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            classifier=classifier,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=classifier_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        replay_observer = [replay_buffer.add_batch]

        real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity,
        )
        real_replay_observer = [real_replay_buffer.add_batch]

        sim_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnSim",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthSim",
            ),
        ]
        real_train_metrics = [
            tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
            tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
            tf_metrics.AverageReturnMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageReturnReal",
            ),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size,
                name="AverageEpisodeLengthReal",
            ),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(
                sim_train_metrics + real_train_metrics, "train_metrics"),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "policy"),
            policy=eval_policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, "replay_buffer"),
            max_to_keep=1,
            replay_buffer=(replay_buffer, real_replay_buffer),
        )

        if checkpoint_dir is not None:
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
            assert checkpoint_path is not None
            train_checkpointer._load_status = train_checkpointer._checkpoint.restore(  # pylint: disable=protected-access
                checkpoint_path)
            train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
        else:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if replay_buffer.num_frames() == 0:
            initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=replay_observer + sim_train_metrics,
                num_steps=initial_collect_steps,
            )
            real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env_real,
                initial_collect_policy,
                observers=real_replay_observer + real_train_metrics,
                num_steps=real_initial_collect_steps,
            )

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + sim_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        real_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env_real,
            collect_policy,
            observers=real_replay_observer + real_train_metrics,
            num_steps=collect_steps_per_iteration,
        )

        config_str = gin.operative_config_str()
        logging.info(config_str)
        with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                     "w") as f:
            f.write(config_str)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            real_initial_collect_driver.run = common.function(
                real_initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            real_collect_driver.run = common.function(real_collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if replay_buffer.num_frames() == 0:
            logging.info(
                "Initializing replay buffer by collecting experience for %d steps with "
                "a random policy.",
                initial_collect_steps,
            )
            initial_collect_driver.run()
            real_initial_collect_driver.run()

        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix="Metrics-%s" % eval_name,
            )
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        real_time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = (replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
        real_dataset = (real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)
        real_iterator = iter(real_dataset)

        def train_step():
            experience, _ = next(iterator)
            real_experience, _ = next(real_iterator)
            return tf_agent.train(experience, real_experience=real_experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            assert not policy_state  # We expect policy_state == ().
            if (global_step.numpy() % real_collect_interval == 0
                    and global_step.numpy() >= delta_r_warmup):
                real_time_step, policy_state = real_collect_driver.run(
                    time_step=real_time_step,
                    policy_state=policy_state,
                )

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info("step = %d, loss = %f", global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info("%.3f steps/sec", steps_per_sec)
                tf.compat.v2.summary.scalar(name="global_steps_per_sec",
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in sim_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=sim_train_metrics[:2])
            for train_metric in real_train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=real_train_metrics[:2])

            if global_step_val % eval_interval == 0:
                for eval_name, eval_env, eval_metrics in zip(
                        eval_name_list, eval_env_list, eval_metrics_list):
                    metric_utils.eager_compute(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix="Metrics-%s" % eval_name,
                    )
                    metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
# Example #6
def bigtable_collect(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        num_episodes=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Collect episodes with a DQN collect policy into a Bigtable replay buffer.

    Builds a Gym environment and a `DqnAgent` on a `QNetwork`, connects to a
    Cloud Bigtable table, and then runs `num_iterations` rounds of
    episode-based collection. Each round feeds `BigtableReplayBuffer.add_batch`
    plus the train metrics. No training, evaluation or checkpointing happens
    here.

    Only the following parameters are read: `env_name`, `num_iterations`,
    `fc_layer_params`, `num_episodes`, `epsilon_greedy`,
    `replay_buffer_capacity`, the target-update params, `learning_rate`,
    `n_step_update`, `gamma`, `reward_scale_factor`, `gradient_clipping`,
    `use_tf_functions`, `debug_summaries` and `summarize_grads_and_vars`. The
    remaining parameters are accepted so the signature stays unchanged.
    """
    global_step = tf.compat.v1.train.get_or_create_global_step()
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    q_net = q_network.QNetwork(tf_env.observation_spec(),
                               tf_env.action_spec(),
                               fc_layer_params=fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy

    # Instantiate the Cloud Bigtable table and GCS bucket.
    # NOTE(review): relies on module-level `SERVICE_ACCOUNT_FILE`, `SCOPES`,
    # `VISUAL_OBS_SPEC` and a parsed-arguments global `args`; none are
    # defined in this function.
    credentials = service_account.Credentials.from_service_account_file(
        SERVICE_ACCOUNT_FILE, scopes=SCOPES)
    cbt_table, gcs_bucket = gcp_load_pipeline(args.gcp_project_id,
                                              args.cbt_instance_id,
                                              args.cbt_table_name,
                                              args.bucket_id, credentials)
    # NOTE(review): the factor 4 presumably assumes 4-byte (float32)
    # observation elements, plus 64 bytes of overhead -- confirm.
    max_row_bytes = (4 * np.prod(VISUAL_OBS_SPEC) + 64)
    # NOTE(review): flush_count comes from `args.num_episodes`, not the
    # `num_episodes` parameter, and `cbt_batcher` is not used below --
    # confirm both are intended.
    cbt_batcher = cbt_table.mutations_batcher(flush_count=args.num_episodes,
                                              max_row_bytes=max_row_bytes)

    bigtable_replay_buffer = BigtableReplayBuffer(
        data_spec=tf_agent.collect_data_spec, max_size=replay_buffer_capacity)

    # `num_episodes` is an argument of the *episode* driver; the module
    # `dynamic_episode_driver` exposes `DynamicEpisodeDriver`.
    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[bigtable_replay_buffer.add_batch] + train_metrics,
        num_episodes=num_episodes)

    # TODO: checkpointing (agent / policy / replay buffer) is not wired up
    # for the Bigtable buffer yet.

    if use_tf_functions:
        # To speed up collect use common.function.
        collect_driver.run = common.function(collect_driver.run)

    for _ in range(num_iterations):
        collect_driver.run()
# Example #7
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG with recurrent actor/critic networks.

    Loads a DM Control environment (optionally flattened to the single
    observation key `observations_whitelist`) and builds a DDPG agent with an
    `ActorRnnNetwork` and a `CriticRnnNetwork`. It fills a uniform replay
    buffer with `initial_collect_episodes` episodes, then alternates episode
    collection with `train_steps_per_iteration` training steps on sequences of
    `train_sequence_length + 1` steps.

    Train summaries are written under `<root_dir>/train` and eval summaries
    under `<root_dir>/eval`. The schedule is keyed off the global step, which
    the agent advances on every train call:
      * summaries are recorded every `summary_interval` steps;
      * loss and throughput are logged every `log_interval` steps;
      * evaluation runs every `eval_interval` steps.

    Args:
      root_dir: Output directory; `~` is expanded.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after every evaluation.
      (Other arguments configure the environment, networks, agent and
      schedule as named.)

    Returns:
      The value returned by the last `tf_agent.train` call, or None if no
      training step ran.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        tf_env = tf_py_environment.TFPyEnvironment(
            suite_dm_control.load(env_name,
                                  task_name,
                                  env_wrappers=env_wrappers))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_dm_control.load(env_name,
                                  task_name,
                                  env_wrappers=env_wrappers))

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            # The agent increments this counter on every train() call, so
            # the logging/eval schedule below tracks actual training steps.
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        def _evaluate():
            """Run the eval episodes, write summaries, notify, and log."""
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d episodes '
            'with a random policy.', initial_collect_episodes)
        initial_collect_driver.run()

        _evaluate()

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0
        # Stays None if no training step ever runs (e.g. num_iterations == 0).
        train_loss = None

        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                _evaluate()

        return train_loss
    def __init__(self,
                 root_dir,
                 env_load_fn=suite_gym.load,
                 env_name='CartPole-v0',
                 num_parallel_environments=1,
                 agent_class=None,
                 num_eval_episodes=30,
                 write_summaries=True,
                 summaries_flush_secs=10,
                 eval_metrics_callback=None,
                 env_metric_factories=None):
        """Evaluate policy checkpoints as they are produced.

        Args:
          root_dir: Main directory for experiment files.
          env_load_fn: Function to load the environment specified by env_name.
          env_name: Name of environment to evaluate in.
          num_parallel_environments: Number of environments to evaluate on in
            parallel.
          agent_class: TFAgent class to instantiate for evaluation.
          num_eval_episodes: Number of episodes to average evaluation over.
          write_summaries: Whether to write summaries to the file system.
          summaries_flush_secs: How frequently to flush summaries (in seconds).
          eval_metrics_callback: A function that will be called with evaluation
            results for every checkpoint.
          env_metric_factories: An iterable of metric factories. Use this for
            eval metrics that need access to the evaluated environment. A
            metric factory is a function that takes `environment` and
            `buffer_size` keyword arguments and returns a py_metric instance;
            each result is wrapped in a `TFPyMetric`.

        Raises:
          ValueError: when agent_class is not set, when
            num_parallel_environments > num_eval_episodes, or when
            env_metric_factories is given but the evaluation environment is
            not a PyEnvironment.
        """
        if not agent_class:
            raise ValueError(
                'The `agent_class` parameter of Evaluator must be set.')
        if num_parallel_environments > num_eval_episodes:
            raise ValueError(
                'num_parallel_environments should not be greater than '
                'num_eval_episodes')

        self._num_eval_episodes = num_eval_episodes
        self._eval_metrics_callback = eval_metrics_callback
        # Flag that controls eval cycle. If set, evaluation will exit eval loop
        # before the max checkpoint number is reached.
        self._terminate_early = False

        # Save root dir to self so derived classes have access to it.
        self._root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(self._root_dir, 'train')
        self._eval_dir = os.path.join(self._root_dir, 'eval')

        self._global_step = tf.compat.v1.train.get_or_create_global_step()

        self._env_name = env_name
        if num_parallel_environments == 1:
            eval_env = env_load_fn(env_name)
        else:
            eval_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)

        if isinstance(eval_env, py_environment.PyEnvironment):
            self._eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)
            self._eval_py_env = eval_env
        else:
            self._eval_tf_env = eval_env
            self._eval_py_env = None  # Can't generically convert to PyEnvironment.

        self._eval_metrics = [
            tf_metrics.AverageReturnMetric(
                buffer_size=self._num_eval_episodes,
                batch_size=self._eval_tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=self._num_eval_episodes,
                batch_size=self._eval_tf_env.batch_size),
        ]
        if env_metric_factories:
            # Compare against the None sentinel set above rather than relying
            # on the truthiness of an environment object.
            if self._eval_py_env is None:
                raise ValueError(
                    'The `env_metric_factories` parameter of Evaluator '
                    'can only be used with a PyEnvironment environment.')
            for metric_factory in env_metric_factories:
                py_metric = metric_factory(environment=self._eval_py_env,
                                           buffer_size=self._num_eval_episodes)
                self._eval_metrics.append(tf_py_metric.TFPyMetric(py_metric))

        if write_summaries:
            self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                self._eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_summary_writer.set_as_default()
        else:
            self._eval_summary_writer = None

        environment_specs.set_observation_spec(
            self._eval_tf_env.observation_spec())
        environment_specs.set_action_spec(self._eval_tf_env.action_spec())

        # Agent params configured with gin.
        self._agent = agent_class(self._eval_tf_env.time_step_spec(),
                                  self._eval_tf_env.action_spec())

        self._eval_policy = greedy_policy.GreedyPolicy(self._agent.policy)
        self._eval_policy.action = common.function(self._eval_policy.action)

        # Run the agent on dummy data to force instantiation of the network. Keras
        # doesn't create variables until you first use the layer. This is needed
        # for checkpoint restoration to work.
        dummy_obs = tensor_spec.sample_spec_nest(
            self._eval_tf_env.observation_spec(),
            outer_dims=(self._eval_tf_env.batch_size, ))
        self._eval_policy.action(
            ts.restart(dummy_obs, batch_size=self._eval_tf_env.batch_size),
            self._eval_policy.get_initial_state(self._eval_tf_env.batch_size))

        self._policy_checkpoint = tf.train.Checkpoint(
            policy=self._agent.policy, global_step=self._global_step)
        self._policy_checkpoint_dir = os.path.join(train_dir, 'policy')
# Example #9
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        critic_learning_rate=3e-4,
        train_sequence_length=20,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control."""
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over full trajectory sampled. The sequence is kept only
            # if all elements except for the last one pass the filter. This is to
            # allow training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, env_steps.numpy())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
示例#10
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(100, ),
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gradient_clipping=None,
        normalize_returns=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        rb_checkpoint_interval=200,
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce, built in TF1 graph mode.

    The graph is constructed first (agent, replay buffer, collect/train/clear
    ops, checkpointers, summaries) and then executed inside a
    ``tf.compat.v1.Session``. Each iteration collects
    ``collect_episodes_per_iteration`` episodes with the agent's collect
    policy, trains once on everything gathered from the replay buffer, and
    then clears the buffer, so each train step only sees the episodes that
    were just collected.

    Args:
      root_dir: Base output directory (``~`` is expanded). Train summaries and
        all checkpoints go under ``root_dir/train``; eval summaries go to
        ``root_dir/eval``.
      env_name: Gym environment id passed to ``suite_gym.load``.
      num_iterations: Number of collect/train/clear iterations to run.
      actor_fc_layers: Fully connected layer sizes of the actor network.
      collect_episodes_per_iteration: Episodes collected per iteration.
      replay_buffer_capacity: ``max_length`` of the replay buffer.
      learning_rate: Adam learning rate.
      gradient_clipping: Forwarded to ``ReinforceAgent``.
      normalize_returns: Forwarded to ``ReinforceAgent``.
      num_eval_episodes: Episodes per evaluation; also the buffer size of
        the eval metrics.
      eval_interval: Evaluate every this many global steps.
      train_checkpoint_interval: Save the agent/metrics checkpoint every this
        many global steps.
      policy_checkpoint_interval: Save the policy checkpoint every this many
        global steps.
      rb_checkpoint_interval: Save the replay-buffer checkpoint every this
        many global steps.
      log_interval: Log loss and steps/sec every this many global steps.
      summary_interval: TF summaries are recorded only when the global step
        is a multiple of this value.
      summaries_flush_secs: Flush interval of both summary writers, seconds.
      debug_summaries: Forwarded to ``ReinforceAgent``.
      summarize_grads_and_vars: Forwarded to ``ReinforceAgent``.
      eval_metrics_callback: Optional callback forwarded to
        ``metric_utils.compute_summaries`` on every evaluation.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    # Default writer for every summary op built below (train metrics,
    # steps/sec).
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    # Eval runs in Python against a py environment, hence py_metrics.
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Every summary op created inside this scope only records when the
    # global step is a multiple of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        # TODO(kbanoop): Handle distributions without gin.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Python wrapper so the TF policy can drive the py eval environment.
        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        # The first two entries (episodes, env steps) double as step metrics
        # for the summaries built further down.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Train on the whole buffer, then clear it (see sess loop below).
        experience = replay_buffer.gather_all()
        train_op = tf_agent.train(experience)
        clear_rb_op = replay_buffer.clear()

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        # Eval summaries always record (record_if(True)) and go to the eval
        # writer rather than the default train writer.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(sguada) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Compute evaluation metrics.
            global_step_call = sess.make_callable(global_step)
            global_step_val = global_step_call()
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            clear_rb_call = sess.make_callable(clear_rb_op)

            timed_at_step = global_step_call()
            time_acc = 0
            # Steps/sec is computed in Python and fed into the graph through
            # this placeholder so it can be written as a summary.
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # tf.contrib was removed in TF 2.x; use the same
            # tf.compat.v2.summary API as the writers above. The v2 op takes
            # the step explicitly.
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                total_loss = train_step_call()
                # Clearing after each train step means the next iteration
                # trains only on newly collected episodes.
                clear_rb_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
示例#11
0
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=None,
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=None,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):
    """Train a safety critic offline from a saved agent and replay buffer.

    Restores an agent from ``load_root_dir/train`` and its replay buffer from
    ``load_root_dir/train/replay_buffer``. It then trains a freshly built
    safety critic network (``SafetyCriticOffline``) and its target copy for
    ``num_global_steps`` steps. Every batch concatenates half a batch sampled
    from the full replay buffer with half a batch sampled from a buffer that
    holds only failure steps and the steps right before them. Summaries, the
    gin config and safety-critic checkpoints are written to
    ``load_root_dir/train/offline/<UTC timestamp>``. The global step is reset
    to 0 after loading the agent.

    Args:
      load_root_dir: Directory whose ``train`` subdirectory holds the agent
        and replay-buffer checkpoints.
      env_load_fn: Callable ``(env_name, gym_env_wrappers=...)`` returning a
        py environment.
      gym_env_wrappers: Gym wrappers passed to ``env_load_fn`` and applied to
        the monitor env. ``None`` (default) means no wrappers.
      monitor: If True, every ``monitor_interval`` steps run one rollout in a
        ``misc.monitor_freq``-wrapped gym env writing to
        ``load_root_dir/rollouts``.
      env_name: Environment name for ``env_load_fn`` / ``gym.make``.
      agent_class: Agent class, or a string key into ``ALGOS``.
      train_metrics_callback: Optional ``callback(results, step)`` invoked
        after each evaluation.
      actor_fc_layers: Actor network layer sizes (presumably must match the
        checkpoint being loaded).
      critic_joint_fc_layers: Critic network layer sizes (presumably must
        match the checkpoint being loaded).
      safety_critic_joint_fc_layers: Layer sizes of the offline safety critic.
      safety_critic_lr: Adam learning rate for the offline safety critic.
      safety_critic_bias_init_val: If set, constant initializer value for the
        safety critic's value-layer bias.
      safety_critic_kernel_scale: If set, scale of a truncated-normal
        fan-in variance-scaling kernel initializer; otherwise a uniform
        fan-in initializer with scale 1/3 is used.
      n_envs: Number of parallel eval environments; defaults to
        ``num_eval_episodes``.
      target_safety: Safety threshold for the eval policy, the thresholded
        classification metrics and ``train_step``. For agents in
        ``SAFETY_AGENTS``, a falsy value falls back to the agent's
        ``_target_safety``.
      fail_weight: If set, failure labels get weight ``fail_weight / 0.5``
        and non-failure labels ``(1 - fail_weight) / 0.5``.
      num_global_steps: Number of safety-critic training steps.
      batch_size: Total batch size (half from each buffer).
      run_eval: If True, evaluate the safe eval policy every
        ``eval_interval`` steps on parallel environments.
      eval_metrics: Extra py metrics wrapped with ``TFPyMetric`` during
        evaluation. Only used when ``run_eval``. ``None`` means none.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate the safety critic (and the policy, if
        ``run_eval``) every this many steps.
      train_checkpoint_interval: Save the safety critic every this many steps.
      summary_interval: Summaries are recorded only when the global step is
        a multiple of this value.
      monitor_interval: Steps between monitor rollouts.
      summaries_flush_secs: Flush interval of all summary writers, seconds.
      debug_summaries: Forwarded to ``train_step``.
      seed: If not None, seeds TF and (best-effort) the eval environments.
    """
    # Avoid shared mutable default arguments.
    gym_env_wrappers = [] if gym_env_wrappers is None else gym_env_wrappers
    eval_metrics = [] if eval_metrics is None else eval_metrics

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        # The two TF metrics come first; the tf_summaries loop in the
        # training loop relies on this order (eval_metrics[:2] / [2:]).
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed is not None:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception as e:  # pylint: disable=broad-except
                # Seeding is best-effort; not every environment supports it.
                logging.warning('Could not seed eval environments: %s', e)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        # NOTE(review): the agent's own safety critic is built with
        # critic_joint_fc_layers, not safety_critic_joint_fc_layers;
        # presumably it must match the architecture in the loaded
        # checkpoint -- confirm.
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    # Offline safety-critic training counts its own steps from zero.
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    # Thresholded metrics report one value per threshold (see summary loop).
    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed is not None:
        tf.compat.v1.set_random_seed(seed)

    # summaries_flush_secs is honored as passed (it used to be overwritten
    # with 10 here, ignoring the caller's value).
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    # Half of each training batch: uniformly sampled 2-step sequences.
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       num_steps=2,
                                       sample_batch_size=batch_size //
                                       2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    # Index masks over the whole buffer used for the failure buffer and the
    # safety-critic eval function.
    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    # Rolling by -1 along axis 1 marks the step immediately before a failure;
    # rolling is_first by +1 marks the step immediately after an episode start.
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    # Padded to the buffer's max length; presumably copy_rb expects a mask
    # over the full storage -- confirm against data_utils.copy_rb.
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    # Other half of each training batch: sequences from the failure buffer.
    sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3,
                                               sample_batch_size=batch_size //
                                               2,
                                               num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            # Drop sequences whose first step is an episode boundary.
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            # Label: the failure indicator of the second step.
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            # Thresholded metric: one scalar per threshold.
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                # NOTE(review): monitor_py_env is a py (GymWrapper) env, yet
                # its unbatched time steps are fed straight to eval_policy,
                # which was built from TF specs; confirm SafeActorPolicyRSVar
                # accepts this.
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
示例#12
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC.

  Trains a SAC agent on `suite_mujoco.load(env_name)` and periodically
  evaluates a greedy version of its policy. Summaries are written to
  `<root_dir>/train` and `<root_dir>/eval`. Agent, policy and replay-buffer
  checkpoints are written under `<root_dir>/train`.

  Args:
    root_dir: Output directory (`~` is expanded).
    env_name: MuJoCo environment name, used for the train and the eval env.
    num_iterations: Number of collect/train iterations. Converted with
      `int()` so float values such as `1e6` are accepted.
    actor_fc_layers: Fully connected layer sizes of the actor network.
    critic_obs_fc_layers, critic_action_fc_layers, critic_joint_fc_layers:
      Layer sizes of the critic's observation, action and joint branches.
    initial_collect_steps: Steps collected with a random policy before
      training starts.
    collect_steps_per_iteration: Environment steps collected per iteration.
    replay_buffer_capacity: Maximum length of the replay buffer.
    target_update_tau, target_update_period: Forwarded to `SacAgent`.
    train_steps_per_iteration: Gradient steps per iteration.
    batch_size: Trajectories sampled from the replay buffer per gradient
      step.
    actor_learning_rate, critic_learning_rate, alpha_learning_rate: Adam
      learning rates of the actor, critic and alpha optimizers.
    td_errors_loss_fn, gamma, reward_scale_factor, gradient_clipping:
      Forwarded unchanged to `SacAgent`.
    use_tf_functions: If True, wraps the collect drivers and
      `tf_agent.train` in `common.function`.
    num_eval_episodes: Episodes per evaluation. Also the eval metric
      buffer size.
    eval_interval: Evaluate every this many global steps.
    train_checkpoint_interval, policy_checkpoint_interval,
      rb_checkpoint_interval: Checkpoint frequencies in global steps.
    log_interval: Log loss and throughput every this many global steps.
    summary_interval: TF summaries are recorded only when the global step
      is a multiple of this value.
    summaries_flush_secs: Flush interval of the summary writers.
    debug_summaries, summarize_grads_and_vars: Forwarded to `SacAgent`.
    eval_metrics_callback: Optional callable, invoked as
      `callback(results, global_step)` after every evaluation.

  Returns:
    The value returned by the last `tf_agent.train` call, or None when no
    training step was executed (e.g. `num_iterations == 0`).
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    def _run_eval():
      """Evaluates `eval_policy` on `eval_tf_env` and logs the eval metrics."""
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    # Baseline evaluation before the first training iteration.
    _run_eval()

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # Stays None if no gradient step runs, so the final return is always
    # defined (previously an UnboundLocalError).
    train_loss = None
    for _ in range(int(num_iterations)):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      # global_step is the agent's train_step_counter, so only
      # tf_agent.train advances it; read it once per iteration.
      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        if train_loss is not None:
          logging.info('step = %d, loss = %f', global_step_val,
                       train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        _run_eval()

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
    def __init__(
        self,
        env,
        global_step,
        root_dir,
        step_metrics,
        name='Agent',
        is_environment=False,
        use_tf_functions=True,
        max_steps=250,
        replace_reward=True,
        non_negative_regret=False,
        id_num=0,
        block_budget_weight=0.,

        # Architecture hparams
        use_rnn=True,
        learning_rate=1e-4,
        actor_fc_layers=(32, 32),
        value_fc_layers=(32, 32),
        lstm_size=(128, ),
        conv_filters=8,
        conv_kernel=3,
        scalar_fc=5,
        entropy_regularization=0.,
        xy_dim=None,

        # Training & logging settings
        num_epochs=25,
        num_eval_episodes=5,
        num_parallel_envs=5,
        replay_buffer_capacity=1001,
        debug_summaries=True,
        summarize_grads_and_vars=True,
    ):
        """Creates the PPO agent with its replay buffer, metrics and checkpoints.

        Args:
          env: An AdversarialTfPyEnvironment providing both the agent specs
            and the adversary specs.
          global_step: tf variable tracking the global step. It is used as
            the agent's train step counter and is stored in checkpoints.
          root_dir: Directory under which train/eval directories, checkpoints
            and the saved policy are placed.
          step_metrics: tf-agents metrics that act as the x-axis during
            training (e.g. number of episodes or environment steps). They
            are also the first entries of this agent's train metrics.
          name: Name of this agent (e.g. 'Adversary'). Used as the TF name
            scope and as a prefix for metric and directory names.
          is_environment: If True, uses the environment's adversary specs and
            builds a network with the extra adversary inputs (a 'time_step'
            scalar, random_z and xy_dim).
          use_tf_functions: If True, wraps the agent's train function in
            `common.function` with autograph disabled.
          max_steps: Maximum number of steps the agent may take in every
            data collection loop.
          replace_reward: If False, the reward stored in the agent's
            trajectories is not modified, i.e. the agent is trained on the
            default environment reward rather than regret.
          non_negative_regret: If True, the regret reward cannot go below 0.
          id_num: ID of this agent within the population of agents of the
            same type (e.g. adversary number 3).
          block_budget_weight: Weight of the adversary's block-budget reward.
            0 (the default) disables it.
          use_rnn: If True, the network architecture contains an RNN.
          learning_rate: Learning rate of this agent's Adam optimizer.
          actor_fc_layers: Number and sizes of fully connected layers in the
            policy.
          value_fc_layers: Number and sizes of fully connected layers in the
            critic / value network.
          lstm_size: Number of LSTM cells in the RNN.
          conv_filters: Number of convolution filters.
          conv_kernel: Width of the convolution kernel.
          scalar_fc: Width of the fully connected layer fed by a scalar input.
          entropy_regularization: Entropy regularization coefficient.
          xy_dim: Maximum x or y value. Required by adversaries that receive
            their current (x, y) position as a one-hot vector.
          num_epochs: Number of epochs for computing PPO policy updates.
          num_eval_episodes: Number of evaluation episodes per eval step.
            Used as the batch size of the eval metrics.
          num_parallel_envs: Number of parallel training environments. Used
            as the batch size of the replay buffer, the train metrics and the
            reward tensors.
          replay_buffer_capacity: Capacity of this agent's replay buffer.
          debug_summaries: If True, the PPO agent logs additional summaries.
          summarize_grads_and_vars: If True, the PPO agent logs gradient
            norms and variances.
        """
        # Plain configuration read by the rest of the class.
        self.name = name
        self.id = id_num
        self.max_steps = max_steps
        self.is_environment = is_environment
        self.replace_reward = replace_reward
        self.non_negative_regret = non_negative_regret
        self.block_budget_weight = block_budget_weight

        with tf.name_scope(self.name):
            self.optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate)

            logging.info('\tCalculating specs and building networks...')
            net_kwargs = dict(
                use_rnns=use_rnn,
                actor_fc_layers=actor_fc_layers,
                value_fc_layers=value_fc_layers,
                lstm_size=lstm_size,
                conv_filters=conv_filters,
                conv_kernel=conv_kernel,
                scalar_fc=scalar_fc,
            )
            if not is_environment:
                self.time_step_spec = env.time_step_spec()
                self.action_spec = env.action_spec()
                self.observation_spec = env.observation_spec()
            else:
                # Adversary specs are read as attributes rather than called.
                self.time_step_spec = env.adversary_time_step_spec
                self.action_spec = env.adversary_action_spec
                self.observation_spec = env.adversary_observation_spec
                # Adversary-only network inputs.
                time_step_dim = self.observation_spec['time_step'].maximum + 1
                net_kwargs.update(
                    scalar_name='time_step',
                    scalar_dim=time_step_dim,
                    random_z=True,
                    xy_dim=xy_dim,
                )
            self.actor_net, self.value_net = (
                multigrid_networks.construct_multigrid_networks(
                    self.observation_spec, self.action_spec, **net_kwargs))

            self.tf_agent = ppo_clip_agent.PPOClipAgent(
                self.time_step_spec,
                self.action_spec,
                self.optimizer,
                actor_net=self.actor_net,
                value_net=self.value_net,
                entropy_regularization=entropy_regularization,
                importance_ratio_clipping=0.2,
                normalize_observations=False,
                normalize_rewards=False,
                use_gae=True,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            self.tf_agent.initialize()
            self.eval_policy = self.tf_agent.policy
            self.collect_policy = self.tf_agent.collect_policy

            logging.info('\tAllocating replay buffer ...')
            self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                self.tf_agent.collect_data_spec,
                batch_size=num_parallel_envs,
                max_length=replay_buffer_capacity)
            logging.info('\t\tRB capacity: %i', self.replay_buffer.capacity)
            # One float slot per parallel environment.
            self.final_reward = tf.zeros(shape=num_parallel_envs,
                                         dtype=tf.float32)
            self.enemy_max = tf.zeros(shape=num_parallel_envs,
                                      dtype=tf.float32)

            # Metrics. Every agent tracks episode length. Regular agents also
            # track average return, while adversaries get an
            # AdversarialEnvironmentScalar instead.
            self.step_metrics = step_metrics
            train_length_metric = tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_envs,
                name=name + '_AverageEpisodeLength')
            self.train_metrics = step_metrics + [train_length_metric]
            eval_length_metric = tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_eval_episodes,
                name=name + '_AverageEpisodeLength')
            self.eval_metrics = [eval_length_metric]
            if not is_environment:
                self.train_metrics.append(
                    tf_metrics.AverageReturnMetric(
                        batch_size=num_parallel_envs,
                        name=name + '_AverageReturn'))
                self.eval_metrics.append(
                    tf_metrics.AverageReturnMetric(
                        batch_size=num_eval_episodes,
                        name=name + '_AverageReturn'))
            else:
                self.env_train_metric = (
                    adversarial_eval.AdversarialEnvironmentScalar(
                        batch_size=num_parallel_envs,
                        name=name + '_AdversaryReward'))
                self.env_eval_metric = (
                    adversarial_eval.AdversarialEnvironmentScalar(
                        batch_size=num_eval_episodes,
                        name=name + '_AdversaryReward'))

            self.metrics_group = metric_utils.MetricsGroup(
                self.train_metrics, name + '_train_metrics')
            # Observer list: the train metrics followed by the buffer writer.
            self.observers = [*self.train_metrics, self.replay_buffer.add_batch]

            # Per-agent output locations: <root>/<kind>/<name>/<id>.
            self.train_dir = os.path.join(root_dir, 'train', name, str(id_num))
            self.eval_dir = os.path.join(root_dir, 'eval', name, str(id_num))
            self.train_checkpointer = common.Checkpointer(
                ckpt_dir=self.train_dir,
                agent=self.tf_agent,
                global_step=global_step,
                metrics=self.metrics_group,
            )
            self.policy_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(self.train_dir, 'policy'),
                policy=self.eval_policy,
                global_step=global_step)
            self.saved_model = policy_saver.PolicySaver(
                self.eval_policy, train_step=global_step)
            self.saved_model_dir = os.path.join(
                root_dir, 'policy_saved_model', name, str(id_num))

            # Initialize from scratch or restore from self.train_dir.
            self.train_checkpointer.initialize_or_restore()

            if use_tf_functions:
                self.tf_agent.train = common.function(
                    self.tf_agent.train, autograph=False)

            # Loss bookkeeping, initially empty.
            self.total_loss = None
            self.extra_loss = None
            self.loss_divergence_counter = 0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=5e5,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(
            64,
            64,
        ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(6, ),
        output_fc_layer_params=(30, ),

        # Params for collect
        initial_collect_steps=2000,
        collect_steps_per_iteration=6,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=6,
        batch_size=32,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Setup:
    - Reads `clusters.pickle` and `graph.gpickle` from the current working
      directory.
    - Creates the Gym environment `env_name` with
      `gym_kwargs={'graph': graph, 'clusters': clusters}`.
    - Trains a DQN agent whose Q-network is a `GATNetwork` over that graph.

    After training, prints the environment's best controllers and best
    reward. It then resets the environment, replays the controllers returned
    by `graphCentroidAction()` and prints the resulting controllers and final
    reward.

    Args:
      root_dir: Output directory (`~` is expanded). Summaries and
        checkpoints go under `train/` and `eval/`.
      env_name: Gym environment id, used for the train and the eval env.
      num_iterations: Number of collect/train iterations. Converted with
        `int()` because the default `5e5` is a float and `range()` rejects
        floats.
      train_sequence_length: If <= 1 it is replaced by `n_step_update`. The
        replay dataset samples `train_sequence_length + 1` steps per item.
      fc_layer_params, input_fc_layer_params, lstm_size,
        output_fc_layer_params: Currently unused because the Q-network is
        always a `GATNetwork`. Kept for interface compatibility.
      initial_collect_steps: Steps collected with a random policy before
        training. These steps also update the train metrics.
      collect_steps_per_iteration: Environment steps collected per iteration.
      epsilon_greedy, n_step_update, target_update_tau, target_update_period,
        gamma, reward_scale_factor, gradient_clipping, debug_summaries,
        summarize_grads_and_vars: Forwarded to `DqnAgent`.
      replay_buffer_capacity: Maximum length of the replay buffer.
      train_steps_per_iteration: Gradient steps per iteration.
      batch_size: Items sampled from the replay buffer per gradient step.
      learning_rate: Learning rate of the Adam optimizer.
      use_tf_functions: If True, wraps the collect driver, `tf_agent.train`
        and the per-step training function in `common.function`.
      num_eval_episodes: Episodes per evaluation. Also the eval metric
        buffer size.
      eval_interval: Evaluate every this many global steps.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint frequencies in global steps.
      log_interval: Log loss and throughput every this many global steps.
      summary_interval: TF summaries are recorded only when the global step
        is a multiple of this value.
      summaries_flush_secs: Flush interval of the summary writers.
      eval_metrics_callback: Optional callable, invoked as
        `callback(results, global_step)` after every evaluation.

    Returns:
      The value returned by the last training step, or None when no training
      step was executed.

    Raises:
      NotImplementedError: If both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    # NOTE(review): unpickling can execute arbitrary code; only ever point
    # this at trusted local files.
    with open('clusters.pickle', 'rb') as clusters_file:
        clusters = pickle.load(clusters_file)
    # NOTE(review): nx.read_gpickle was deprecated in networkx 2.6 and removed
    # in 3.0 -- confirm the installed networkx version.
    graph = nx.read_gpickle('graph.gpickle')
    print(graph.nodes)
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        def _make_env():
            """Builds a TF environment wired to the shared graph/clusters."""
            return tf_py_environment.TFPyEnvironment(
                suite_gym.load(env_name,
                               gym_kwargs={
                                   'graph': graph,
                                   'clusters': clusters
                               }))

        tf_env = _make_env()
        eval_tf_env = _make_env()

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length <= 1:
            # Non-recurrent path: sample n_step_update + 1 steps per item.
            train_sequence_length = n_step_update
        # The Q-network is always the graph attention network. The
        # feed-forward/recurrent networks built here previously were
        # immediately overwritten, so they are no longer constructed.
        q_net = GATNetwork(tf_env.observation_spec(), tf_env.action_spec(),
                           graph)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
            tf_metrics.MaxReturnMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        def _run_eval():
            """Evaluates `eval_policy` on `eval_tf_env` and logs the metrics."""
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Baseline evaluation before the first training iteration.
        _run_eval()

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx(T+1)x...].
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Stays None if no gradient step runs, so the final return is always
        # defined (previously an UnboundLocalError).
        train_loss = None
        for _ in range(int(num_iterations)):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # global_step is the agent's train_step_counter, so only training
            # advances it; read it once per iteration.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                _run_eval()

        # Compare the learned result against the graph-centroid heuristic.
        gym_env = tf_env.envs[0]._gym_env
        print(gym_env.best_controllers)
        print(gym_env.best_reward)
        gym_env.reset()
        centroid_controllers, _ = gym_env.graphCentroidAction()
        print(centroid_controllers)
        # Stays None if the heuristic returns no controllers.
        reward_final = None
        for cont in centroid_controllers:
            (_, reward_final, _, _) = gym_env.step(cont)
        print(gym_env.controllers, reward_final)
        return train_loss
# Example #15
# 0
def train_eval(
        root_dir,
        env_name='Blob2d-v1',
        num_iterations=100000,
        train_sequence_length=1,
        collect_steps_per_iteration=1,
        initial_collect_steps=1500,
        replay_buffer_max_length=10000,
        batch_size=64,
        learning_rate=1e-3,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for QNetwork
        fc_layer_params=(100, ),
        use_tf_functions=False,
        ## train params
        train_steps_per_iteration=1,
        train_checkpoint_interval=1000,
        policy_checkpoint_interval=1000,
        rb_checkpoint_interval=1000,
        n_step_update=1,
        ## Params for Summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train a DQN agent on a Gym environment and evaluate it periodically.

    Builds train/eval TF environments from `env_name`, a feed-forward
    `QNetwork` and a `DqnAgent`. It then seeds a uniform replay buffer with
    `initial_collect_steps` of random-policy experience and alternates
    collection and training for `num_iterations` iterations. Summaries are
    written to `<root_dir>/train` and `<root_dir>/eval`. Agent, policy and
    replay-buffer checkpoints are saved under `<root_dir>/train`.

    Args:
      root_dir: Output directory (`~` is expanded).
      env_name: Environment id passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations.
      train_sequence_length: Must be 1 when `n_step_update != 1`. Because the
        Q-network here is feed-forward, it is replaced by `n_step_update`.
      collect_steps_per_iteration: Environment steps collected per iteration.
      initial_collect_steps: Random-policy steps used to seed the buffer.
      replay_buffer_max_length: `max_length` of the replay buffer.
      batch_size: Sample batch size drawn from the replay buffer.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation (also the metric buffer).
      eval_interval: Evaluate when the global step is a multiple of this.
      fc_layer_params: Hidden layer sizes of the `QNetwork`.
      use_tf_functions: Wrap the train step in `common.function`.
      train_steps_per_iteration: Train steps run per iteration.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint cadence in global steps.
      n_step_update: Forwarded to `DqnAgent`.
      log_interval: Log loss and steps/sec at multiples of this step.
      summary_interval: Record summaries at multiples of this step.
      summaries_flush_secs: Summary writer flush period.
      debug_summaries, summarize_grads_and_vars: Forwarded to `DqnAgent`.
      eval_metrics_callback: Optional `callback(results, step)` invoked
        after every evaluation.

    Returns:
      The value returned by the last `agent.train` call (its `.loss` field is
      logged), or None if no training step ran.

    Raises:
      NotImplementedError: if both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Setup and the training loop both run under record_if, so that
    # `summary_interval` actually throttles the summaries written while
    # training (previously only environment creation was inside it).
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Build both environments from `env_name`. The training env used to be
        # rebuilt from a hard-coded 'Blob2d-v1', which ignored the argument.
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        # Honour the caller's `fc_layer_params` (it used to be overwritten).
        q_net = q_network.QNetwork(tf_env.observation_spec(),
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)
        # With a feed-forward Q-network the sampled sequence length must follow
        # n_step_update (dataset num_steps = n_step_update + 1 below).
        train_sequence_length = n_step_update

        agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            n_step_update=n_step_update,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = agent.policy
        collect_policy = agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_max_length)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx(train_sequence_length+1)x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Stays None when no train step runs (num_iterations == 0 or
        # train_steps_per_iteration == 0) instead of being unbound.
        train_loss = None

        # Main Training loop.
        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', global_step.numpy(),
                                 train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
示例#16
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN.

    Uses a `QRnnNetwork` when `train_sequence_length > 1`. Otherwise it uses
    a feed-forward `QNetwork`, and the sampled sequence length then follows
    `n_step_update`. The replay buffer is seeded with `initial_collect_steps`
    of random-policy experience. The function then alternates collection and
    training for `num_iterations` iterations, logging, checkpointing and
    evaluating at the configured step intervals. Summaries are written to
    `<root_dir>/train` and `<root_dir>/eval`.

    Args:
      root_dir: Output directory (`~` is expanded).
      env_name: Environment id passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations.
      train_sequence_length: > 1 selects the recurrent Q-network.
      fc_layer_params: Layer sizes for the feed-forward `QNetwork`.
      input_fc_layer_params, lstm_size, output_fc_layer_params: Layer sizes
        for the `QRnnNetwork`.
      initial_collect_steps, collect_steps_per_iteration: Collection sizes.
      epsilon_greedy, target_update_tau, target_update_period, n_step_update,
        gamma, reward_scale_factor, gradient_clipping, debug_summaries,
        summarize_grads_and_vars: Forwarded to `DqnAgent`.
      replay_buffer_capacity: `max_length` of the replay buffer.
      train_steps_per_iteration: Train steps run per iteration.
      batch_size: Sample batch size drawn from the replay buffer.
      learning_rate: Adam learning rate.
      use_tf_functions: Wrap collect/train in `common.function`.
      num_eval_episodes: Episodes per evaluation (also the metric buffer).
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint cadence in global steps.
      log_interval: Log loss and steps/sec at multiples of this step.
      summary_interval: Record summaries at multiples of this step.
      summaries_flush_secs: Summary writer flush period.
      eval_metrics_callback: Optional `callback(results, step)` invoked
        after every evaluation.

    Returns:
      The value returned by the last train step (its `.loss` field is
      logged), or None if no training step ran.

    Raises:
      NotImplementedError: if both `train_sequence_length` and
        `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
            # The feed-forward network samples n_step_update + 1 steps below.
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Stays None when no train step runs (num_iterations == 0 or
        # train_steps_per_iteration == 0) instead of being unbound.
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                if train_loss is not None:
                    logging.info('step = %d, loss = %f', global_step.numpy(),
                                 train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
示例#17
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        actor_fc_layers=(100, ),
        value_net_fc_layers=(100, ),
        use_value_network=False,
        use_tf_functions=True,
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gamma=0.9,
        gradient_clipping=None,
        normalize_returns=True,
        value_estimation_loss_coef=0.2,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce.

    Each iteration collects `collect_episodes_per_iteration` full episodes
    with the agent's collect policy. It trains once on everything in the
    replay buffer (`gather_all`) and then clears the buffer. Evaluation runs
    once before training and again whenever the global step is a multiple
    of `eval_interval`.

    Args:
      root_dir: Output directory (`~` is expanded); summaries go to
        `<root_dir>/train` and `<root_dir>/eval`.
      env_name: Environment id passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations.
      actor_fc_layers: Layer sizes of the actor distribution network.
      value_net_fc_layers: Layer sizes of the optional value network.
      use_value_network: Build a value network and pass it to the agent.
      use_tf_functions: Wrap collect/train in `common.function`.
      collect_episodes_per_iteration: Episodes collected per iteration.
      replay_buffer_capacity: `max_length` of the replay buffer.
      learning_rate: Adam learning rate.
      gamma, gradient_clipping, normalize_returns, value_estimation_loss_coef,
        debug_summaries, summarize_grads_and_vars: Forwarded to
        `ReinforceAgent`.
      num_eval_episodes: Episodes per evaluation (also the metric buffer).
      eval_interval: Evaluate when the global step is a multiple of this.
      log_interval: Log loss and steps/sec at multiples of this step.
      summary_interval: Record summaries at multiples of this step.
      summaries_flush_secs: Summary writer flush period.
      eval_metrics_callback: Optional `callback(metrics, step)` invoked
        after every evaluation.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Create the step counter before the record_if lambda that reads it. The
    # lambda closes over this local, so assigning it later inside the `with`
    # block would raise NameError if any summary were recorded first.
    global_step = tf.compat.v1.train.get_or_create_global_step()

    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        # Optional baseline; stays None unless requested.
        value_net = None
        if use_value_network:
            value_net = value_network.ValueNetwork(
                tf_env.time_step_spec().observation,
                fc_layer_params=value_net_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            value_network=value_net,
            value_estimation_loss_coef=value_estimation_loss_coef,
            gamma=gamma,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            # Train on the entire buffer contents (cleared after each step).
            experience = replay_buffer.gather_all()
            return tf_agent.train(experience)

        if use_tf_functions:
            # To speed up collect use TF function.
            collect_driver.run = common.function(collect_driver.run)
            # To speed up train use TF function.
            tf_agent.train = common.function(tf_agent.train)
            train_step = common.function(train_step)

        # Compute evaluation metrics.
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        # TODO(b/126590894): Move this functionality into eager_compute_summaries
        if eval_metrics_callback is not None:
            eval_metrics_callback(metrics, global_step.numpy())

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            total_loss = train_step()
            replay_buffer.clear()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                metrics = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                # TODO(b/126590894): Move this functionality into
                # eager_compute_summaries.
                if eval_metrics_callback is not None:
                    eval_metrics_callback(metrics, global_step_val)
示例#18
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=100000,
    exploration_noise_std=0.1,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_update_period=2,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):

  """A simple train and eval for TD3.

  Runs in TF1 graph mode. The collect, train and summary ops are built
  first and then executed inside a `tf.compat.v1.Session`. Training uses a
  TF environment wrapping `suite_mujoco.load(env_name)`. Evaluation runs the
  agent's policy (wrapped in `py_tf_policy.PyTFPolicy`) on a separate Python
  environment, with Python metrics, via `metric_utils.compute_summaries`.

  Args:
    root_dir: Output directory (`~` is expanded). Summaries go to
      `<root_dir>/train` and `<root_dir>/eval`; checkpoints go under
      `<root_dir>/train`.
    env_name: Environment id passed to `suite_mujoco.load`.
    num_iterations: Number of collect/train iterations in the main loop.
    actor_fc_layers: Layer sizes of the actor network.
    critic_obs_fc_layers, critic_action_fc_layers, critic_joint_fc_layers:
      Layer sizes of the critic network.
    initial_collect_steps: Steps collected (with the collect policy) before
      the main loop starts.
    collect_steps_per_iteration: Steps collected per iteration.
    replay_buffer_capacity: `max_length` of the replay buffer.
    exploration_noise_std, target_update_tau, target_update_period,
      actor_update_period, dqda_clipping, td_errors_loss_fn, gamma,
      reward_scale_factor, gradient_clipping, debug_summaries,
      summarize_grads_and_vars: Forwarded to `Td3Agent`.
    train_steps_per_iteration: Train ops run per iteration.
    batch_size: Sample batch size drawn from the replay buffer.
    actor_learning_rate, critic_learning_rate: Adam learning rates.
    num_eval_episodes: Episodes per evaluation (also the metric buffer size).
    eval_interval: Evaluate when the global step is a multiple of this.
    train_checkpoint_interval, policy_checkpoint_interval,
      rb_checkpoint_interval: Checkpoint cadence in global steps.
    log_interval: Log loss and steps/sec at multiples of this step.
    summary_interval: Training summary ops are built under a record_if
      condition of `global_step % summary_interval == 0`.
    summaries_flush_secs: Summary writer flush period.
    eval_metrics_callback: Passed as `callback` to
      `metric_utils.compute_summaries`.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  # Evaluation uses Python-side metrics (py_metrics) because it runs on a
  # plain Python environment below, not a TF environment.
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_py_env = suite_mujoco.load(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    # The critic takes (observation, action) pairs as input.
    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = td3_agent.Td3Agent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        exploration_noise_std=exploration_noise_std,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        actor_update_period=actor_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
    )

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    # Python-callable wrapper of the greedy policy, used for evaluation.
    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    # In graph mode `.run()` only builds the collect ops; they are executed
    # later via sess.run / sess.make_callable. Note that the initial
    # collection uses the agent's collect_policy, not a random policy.
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Sample sequences of 2 consecutive steps from the replay buffer.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    # Graph-mode iterator; it is initialized inside the session below.
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()

    train_fn = common.function(tf_agent.train)
    train_op = train_fn(experience=trajectories)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Train-metric summary ops; fetched together with train_op below.
    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    # Build eval-metric summary ops in the eval writer's context, always
    # recording. Their return values are not kept here.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(b/126239733): Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      # Evaluate once before the main loop.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      # Each call returns (loss_info, summary results, global step value).
      train_step_call = sess.make_callable([train_op, summary_ops, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      # This summary op is added to the graph after the session has started;
      # it is fed the Python-computed steps/sec rate in the loop below.
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        # NOTE(review): loss_info_value is never assigned if
        # train_steps_per_iteration == 0, which makes the logging branch
        # below raise.
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
示例#19
0
def DDPG_Bipedal(root_dir):
    """Train a DDPG agent with periodic evaluation and video recording.

    Builds collect/eval environments for the module-level ``env_name``, an
    actor and a critic network, a ``DdpgAgent`` and a uniform replay buffer.
    The buffer is seeded with ``initial_collect_steps`` steps from the agent's
    collect policy, then ``num_iterations`` collect/train iterations are run.
    Every ``eval_interval`` global steps the agent's policy is evaluated; when
    the evaluated ``AverageReturn`` is at least 230 a video is written under
    ``vid_dir``.

    All hyperparameters (``run_id``, ``num_iterations``, learning rates,
    intervals, ...) are read from module-level globals, not arguments.

    Args:
        root_dir: Base output directory (``~`` is expanded). Summaries and
            videos go to ``train/<run_id>``, ``eval/<run_id>`` and
            ``vid/<run_id>`` beneath it.

    Returns:
        The value returned by the last ``agent.train`` call, or ``None`` if
        no training step was run.
    """
    # Per-run output directories for train/eval summaries and videos.
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train', str(run_id))
    eval_dir = os.path.join(root_dir, 'eval', str(run_id))
    vid_dir = os.path.join(root_dir, 'vid', str(run_id))

    # Train summaries go to the default writer; eval summaries are written
    # explicitly through eval_summary_writer.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        # Average return, buffered over num_eval_episodes episodes.
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        # Average episode length, buffered over num_eval_episodes episodes.
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Default-writer summaries are only recorded every summary_interval steps.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Separate environment instances for collection and evaluation.
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        # Actor network: observation -> action.
        actorNN = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=(400, 300),
        )

        # Critic network: takes an (observation, action) pair as input.
        critic_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())
        criticNN = critic_network.CriticNetwork(
            critic_input_specs,
            observation_fc_layer_params=(400, ),
            action_fc_layer_params=None,
            joint_fc_layer_params=(300, ),
        )

        agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actorNN,
            critic_network=criticNN,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=gamma,
            train_step_counter=global_step)
        agent.initialize()

        # Train metrics observed by the collect driver. The first two (episode
        # and step counters) are also passed as step_metrics to tf_summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = agent.policy  # Actor policy.
        collect_policy = agent.collect_policy  # Actor policy with OU noise.

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        # Driver that fills the replay buffer before training starts; the
        # actor's parameters are still at their random initialization here.
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        # Driver that collects collect_steps_per_iteration steps per iteration
        # and also updates the train metrics.
        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            agent.train = common.function(agent.train)

        # The step count is now actually interpolated into the message.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps '
            'with a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        # Baseline evaluation before any training.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Sample batches of 64 two-step trajectories from the replay buffer.
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=64,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            """Run ``agent.train`` on the next batch from the replay buffer."""
            experience, _ = next(iterator)
            return agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Stays None when no training step runs (e.g. num_iterations == 0),
        # instead of raising UnboundLocalError at the return below.
        train_loss = None
        for _ in range(num_iterations):
            start_time = time.time()
            # Collect fresh experience, continuing from the last time step.
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='iterations_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                # Record a video once the evaluated return reaches 230.
                if results['AverageReturn'].numpy() >= 230.0:
                    create_video(video_dir=vid_dir,
                                 env_name="BipedalWalker-v2",
                                 vid_policy=eval_policy,
                                 video_id=global_step.numpy())
    return train_loss
示例#20
0
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    agent_class=dqn_agent.DqnAgent,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN, run in a TF1 graph/Session.

  Builds a ``QNetwork`` agent (``agent_class``) for the gym environment
  ``env_name``. It fills a uniform replay buffer with
  ``initial_collect_steps`` steps from a random policy, then runs
  ``num_iterations`` collect/train iterations. Checkpoints are written under
  ``root_dir/train`` and evaluation summaries under ``root_dir/eval``.

  Args:
    root_dir: Base output directory (``~`` is expanded).
    env_name: Environment id passed to ``suite_gym.load``.
    num_iterations: Number of collect/train iterations.
    fc_layer_params: Hidden layer sizes of the Q-network.
    initial_collect_steps: Steps collected with a ``RandomTFPolicy`` before
      training.
    collect_steps_per_iteration: Environment steps collected per iteration.
    epsilon_greedy: Exploration epsilon forwarded to the agent.
    replay_buffer_capacity: ``max_length`` of the replay buffer.
    target_update_tau: Forwarded to the agent.
    target_update_period: Forwarded to the agent.
    train_steps_per_iteration: Training steps run per iteration.
    batch_size: Sample batch size of the replay-buffer dataset.
    learning_rate: Adam optimizer learning rate.
    gamma: Forwarded to the agent.
    reward_scale_factor: Forwarded to the agent.
    gradient_clipping: Forwarded to the agent.
    num_eval_episodes: Episodes per evaluation run.
    eval_interval: Evaluate whenever the global step is a multiple of this.
    train_checkpoint_interval: Save the train checkpoint at this step interval.
    policy_checkpoint_interval: Save the policy checkpoint at this interval.
    rb_checkpoint_interval: Save the replay-buffer checkpoint at this interval.
    log_interval: Log loss and throughput at this step interval.
    summary_interval: Record train summaries at this step interval.
    summaries_flush_secs: Flush interval of the summary writers, in seconds.
    agent_class: Agent constructor; must accept the keyword arguments used
      below.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.
    eval_metrics_callback: Optional callback forwarded to
      ``metric_utils.compute_summaries`` on every evaluation.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Train summaries are only recorded every summary_interval steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_py_env = suite_gym.load(env_name)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    # Python-side wrapper so the policy can drive the py eval environment.
    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_observer = [replay_buffer.add_batch]
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    train_op = common.function(tf_agent.train)(experience=experience)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Summary ops are built once here and fetched together with train_op.
    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph, restoring from checkpoints when present.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      # Baseline evaluation before training.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      collect_call = sess.make_callable(collect_op)
      global_step_call = sess.make_callable(global_step)
      train_step_call = sess.make_callable([train_op, summary_ops])

      timed_at_step = global_step_call()
      collect_time = 0
      train_time = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        collect_time += time.time() - start_time
        start_time = time.time()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        train_time += time.time() - start_time

        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('%.3f steps/sec', steps_per_sec)
          logging.info('collect_time = %s, train_time = %s',
                       collect_time, train_time)
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          # log=True so periodic evaluations are logged like the baseline one.
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
示例#21
0
def load_agents_and_create_videos(root_dir,
        env_name='CartPole-v0',
        num_iterations=NUM_ITERATIONS,
        max_ep_steps=1000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(128, 64, 32),
        # Params for QRnnNetwork
        input_fc_layer_params=(50,),
        lstm_size=(20,),
        output_fc_layer_params=(20,),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=10000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        num_random_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None,
        random_metrics_callback=None):
    """Restore a trained DQN agent and record videos of it and of a random agent.

    Rebuilds the environment, Q-network and ``DqnAgent`` from the given
    configuration, which should match the training run. It restores the agent,
    train metrics and replay buffer from the checkpoints under
    ``root_dir/train``, then calls ``create_policy_eval_video`` once for the
    trained policy and once for a ``RandomTFPolicy``. Both video names end
    with the current local timestamp.

    Only the environment, network and agent parameters, plus ``root_dir`` and
    ``replay_buffer_capacity``, are used here. The remaining keyword
    arguments (collect, train-loop, eval, interval and summary settings) are
    accepted for signature compatibility and ignored.

    Args:
        root_dir: Directory containing the ``train`` checkpoints (``~`` is
            expanded).
        env_name: Gym environment id passed to ``suite_gym.load``.
        max_ep_steps: ``max_episode_steps`` for the loaded environments.
        train_sequence_length: Values > 1 select a ``QRnnNetwork``;
            otherwise a ``QNetwork`` is used.
        n_step_update: Forwarded to the agent.

    Raises:
        NotImplementedError: If both ``train_sequence_length`` and
            ``n_step_update`` differ from 1.
    """
    # Expand '~' like the training entry points do, so the same root_dir
    # string resolves to the same checkpoint directory.
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Match the environments used in training.
    tf_env = tf_py_environment.TFPyEnvironment(
        suite_gym.load(env_name, max_episode_steps=max_ep_steps))
    eval_py_env = suite_gym.load(env_name, max_episode_steps=max_ep_steps)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    if train_sequence_length != 1 and n_step_update != 1:
        raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

    if train_sequence_length > 1:
        q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
    else:
        q_net = q_network.QNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=fc_layer_params)

    # Match the agent used in training so the checkpoint variables line up.
    tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

    tf_agent.initialize()

    train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),]

    eval_policy = tf_agent.policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

    train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))

    # Constructed for parity with training; it is not restored below.
    policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)

    rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

    # Load the data from training.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Random policy used as a visual baseline.
    random_policy = random_tf_policy.RandomTFPolicy(eval_tf_env.time_step_spec(),
                                                    eval_tf_env.action_spec())

    # Make movies of the trained agent and a random agent.
    date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')

    trained_filename = "trained-agent" + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, tf_agent.policy, trained_filename)

    # No stray space: both names now follow the same "<kind>-agent<date>" form.
    random_filename = 'random-agent' + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, random_policy, random_filename)
示例#22
0
    def train(self,
              training_iterations=TRAINING_ITERATIONS,
              training_stock_list=None):
        """Train ``self.tf_agent`` with periodic evaluation and checkpointing.

        Calls ``self.reset(training_stock_list)`` and restores any train and
        replay-buffer checkpoints under ``training_data_progress/train-<name>``.
        Prints baseline results for a random policy and the untrained agent,
        seeds the replay buffer with ``INIT_COLLECT_STEPS`` random-policy
        steps, then runs ``training_iterations`` collect/train iterations.
        Whenever the global step is a multiple of ``EVAL_INTERVAL``, the agent
        is evaluated, progress is appended and checkpoints are saved. Training
        stops early if ``avg_daily_percentage`` equals the value from the
        previous evaluation. Afterwards the results are merged into the
        training report, and the agent is saved and reset.

        Args:
            training_iterations: Number of collect/train iterations to run.
            training_stock_list: Forwarded to ``self.reset`` before training.
        """
        self.reset(training_stock_list)

        train_dir = 'training_data_progress/train-' + self.name

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=self.tf_agent.collect_data_spec,
            batch_size=self.tf_training_env.batch_size,
            max_length=MAX_BUFFER_SIZE)

        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=NUM_EVAL_EPISODES),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=NUM_EVAL_EPISODES)
        ]

        global_step = self.tf_agent.train_step_counter
        # Summaries are only recorded every LOG_INTERVAL train steps.
        with tf.compat.v2.summary.record_if(
                lambda: tf.math.equal(global_step % LOG_INTERVAL, 0)):

            replay_observer = [replay_buffer.add_batch]

            train_metrics = [
                tf_metrics.NumberOfEpisodes(),
                tf_metrics.EnvironmentSteps(),
                tf_metrics.AverageReturnMetric(
                    buffer_size=NUM_EVAL_EPISODES,
                    batch_size=self.tf_training_env.batch_size),
                tf_metrics.AverageEpisodeLengthMetric(
                    buffer_size=NUM_EVAL_EPISODES,
                    batch_size=self.tf_training_env.batch_size),
            ]

            eval_policy = greedy_policy.GreedyPolicy(self.tf_agent.policy)
            initial_collect_policy = random_tf_policy.RandomTFPolicy(
                self.tf_training_env.time_step_spec(),
                self.tf_training_env.action_spec())
            collect_policy = self.tf_agent.collect_policy

            train_checkpointer = common.Checkpointer(
                ckpt_dir=train_dir,
                agent=self.tf_agent,
                global_step=global_step,
                metrics=metric_utils.MetricsGroup(train_metrics,
                                                  'train_metrics'))
            policy_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'policy'),
                policy=eval_policy,
                global_step=global_step)
            rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                max_to_keep=1,
                replay_buffer=replay_buffer)

            # Resume from earlier checkpoints, if present.
            train_checkpointer.initialize_or_restore()
            rb_checkpointer.initialize_or_restore()

            initial_collect_driver_random = dynamic_step_driver.DynamicStepDriver(
                self.tf_training_env,
                initial_collect_policy,
                observers=replay_observer + train_metrics,
                num_steps=INIT_COLLECT_STEPS)
            initial_collect_driver_random.run = common.function(
                initial_collect_driver_random.run)

            collect_driver = dynamic_step_driver.DynamicStepDriver(
                self.tf_training_env,
                collect_policy,
                observers=replay_observer + train_metrics,
                num_steps=STEP_ITERATIONS)

            collect_driver.run = common.function(collect_driver.run)
            self.tf_agent.train = common.function(self.tf_agent.train)

            # Baseline: random policy.
            random_policy = random_tf_policy.RandomTFPolicy(
                self.tf_training_env.time_step_spec(),
                self.tf_training_env.action_spec())
            avg_return, avg_return_per_step, avg_daily_percentage = self.compute_avg_return(
                random_policy)
            print(
                'Random:\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
                .format(avg_return, avg_return_per_step, avg_daily_percentage))
            self.gym_training_env.save_feature_distribution(self.name)

            # Baseline: the agent before this training run.
            avg_return, avg_return_per_step, avg_daily_percentage = self.compute_avg_return(
                self.tf_agent.policy)
            print(
                'Agent :\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
                .format(avg_return, avg_return_per_step, avg_daily_percentage))
            self.eval_env.reset()
            self.eval_env.run_and_save_evaluation(str(0))
            self.gym_training_env.save_feature_distribution(self.name)

            # Reference value for the early-stop check in the training loop.
            last_avg_daily_percentage = avg_daily_percentage

            evaluations = [self.get_evaluation()]
            returns = [self.eval_env.returns]
            actions_over_time_list = [self.eval_env.action_sets_over_time]

            # Collect initial replay data.
            print(
                'Initializing replay buffer by collecting experience for {} steps with '
                'a random policy.'.format(INIT_COLLECT_STEPS))
            initial_collect_driver_random.run()

            metric_utils.eager_compute(
                eval_metrics,
                self.tf_training_env,
                eval_policy,
                num_episodes=NUM_EVAL_EPISODES,
                train_step=global_step,
                summary_prefix='Metrics',
            )
            metric_utils.log_metrics(eval_metrics)

            time_step = None
            policy_state = collect_policy.get_initial_state(
                self.tf_training_env.batch_size)

            timed_at_step = global_step.numpy()
            time_acc = 0

            # Prepare replay buffer as dataset with invalid transitions filtered.
            def _filter_invalid_transition(trajectories, unused_arg1):
                # Drop samples whose first step is an episode boundary.
                return ~trajectories.is_boundary()[0]

            dataset = replay_buffer.as_dataset(
                sample_batch_size=BATCH_SIZE, num_steps=2).unbatch().filter(
                    _filter_invalid_transition).batch(BATCH_SIZE).prefetch(5)
            # Dataset generates trajectories with shape [Bx2x...]
            iterator = iter(dataset)

            def _train_step():
                # NOTE(review): on failure this returns a float, and the
                # `train_loss.loss` accesses below would then raise
                # AttributeError. Under common.function the try/except likely
                # only sees trace-time errors — confirm the intended fallback.
                try:
                    experience, _ = next(iterator)
                    return self.tf_agent.train(experience)
                except Exception as e:
                    print("Caught Exception:", e)
                    return 1e-20

            train_step = common.function(_train_step)

            for _ in range(training_iterations):
                start_time = time.time()
                time_step, policy_state = collect_driver.run(
                    time_step=time_step,
                    policy_state=policy_state,
                )
                for _ in range(STEP_ITERATIONS):
                    train_loss = train_step()
                time_acc += time.time() - start_time

                self.global_step_val = global_step.numpy()

                if self.global_step_val % LOG_INTERVAL == 0:
                    steps_per_sec = (self.global_step_val -
                                     timed_at_step) / time_acc
                    print(
                        self.name,
                        '\nstep = {0:d}:\n\tloss = {1:f}\n\t{2:.3f} steps/sec'.
                        format(self.global_step_val, train_loss.loss,
                               steps_per_sec))
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)
                    timed_at_step = self.global_step_val
                    time_acc = 0

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:2])

                if self.global_step_val % EVAL_INTERVAL == 0:
                    metric_utils.eager_compute(
                        eval_metrics,
                        self.tf_training_env,
                        eval_policy,
                        num_episodes=NUM_EVAL_EPISODES,
                        train_step=global_step,
                        summary_prefix='Metrics',
                    )
                    metric_utils.log_metrics(eval_metrics)

                    avg_return, avg_return_per_step, avg_daily_percentage = self.compute_avg_return(
                        self.tf_agent.policy)
                    print(
                        self.name,
                        '\nstep = {0}:\n\tloss = {1}\n\tAverage Return = {2}\n\tAverage Return Per Step = {3}\n\tPercent = {4}%'
                        .format(self.global_step_val, train_loss.loss,
                                avg_return, avg_return_per_step,
                                avg_daily_percentage))
                    self.eval_env.reset()
                    self.eval_env.run_and_save_evaluation(
                        str(self.global_step_val // EVAL_INTERVAL))
                    self.gym_training_env.save_feature_distribution(self.name)

                    # Early stop: compare like with like, i.e. this evaluation's
                    # avg_daily_percentage against the previous evaluation's.
                    if avg_daily_percentage == last_avg_daily_percentage:
                        print("---- Average return did not change since last time. Breaking loop.")
                        break
                    last_avg_daily_percentage = avg_daily_percentage

                    evaluations.append(self.get_evaluation())
                    returns.append(self.eval_env.returns)
                    actions_over_time_list.append(
                        self.eval_env.action_sets_over_time)

                    train_checkpointer.save(global_step=self.global_step_val)
                    policy_checkpointer.save(global_step=self.global_step_val)
                    rb_checkpointer.save(global_step=self.global_step_val)

        # Merge this run's results into the persisted training report.
        training_report = util.load_training_report()
        agent_report = training_report.get(self.name, dict())
        agent_report["Training Results"] = returns
        agent_report["Evaluations"] = [max(e, 0.0) for e in evaluations]
        # Ten bins over [0, 1], shifted slightly left so 0.0 lands in the first.
        bins = [0.1 * i - 0.0000001 for i in range(11)]
        agent_report["Histograms"] = [
            str(list(map(int,
                         np.histogram(actions, bins, density=True)[0])))
            for actions in actions_over_time_list
        ]
        training_report[self.name] = agent_report
        util.save_training_report(training_report)

        print("---- Average-daily-percentage over training period for",
              self.name)
        print("\t\t", avg_daily_percentage)
        self.save()
        self.reset()
示例#23
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        lstm_size=(20, ),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO.

    Builds a PPO-clip agent (feed-forward or RNN actor/value networks) and
    repeatedly:
      1. collects `collect_episodes_per_iteration` episodes from
         `num_parallel_environments` parallel copies of `env_name`,
      2. trains on everything in the replay buffer, then clears it,
    until the environment-step metric reaches `num_environment_steps`.
    The policy is evaluated every `eval_interval` train steps and once more
    before returning.

    Outputs written under `root_dir` (with `~` expanded):
      * `train/`: train summaries and agent/train-metric checkpoints,
      * `train/policy/`: policy checkpoints,
      * `eval/`: evaluation summaries,
      * `policy_saved_model/policy_<9-digit step>`: exported SavedModels.

    Args:
      root_dir: Base output directory. Required.
      env_name: Environment name passed to `env_load_fn`.
      env_load_fn: Callable `env_name -> py environment`.
      random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
      actor_fc_layers, value_fc_layers: Dense layer sizes of the actor and
        value networks (input layers when `use_rnns` is True).
      use_rnns: Use LSTM-based actor/value networks instead of feed-forward.
      lstm_size: LSTM sizes for the RNN actor network.
      num_environment_steps: Total environment steps to collect.
      collect_episodes_per_iteration: Episodes collected per iteration.
      num_parallel_environments: Number of parallel training environments.
      replay_buffer_capacity: Replay-buffer length per environment.
      num_epochs: PPO epochs per `train` call.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation (and eval metric buffer).
      eval_interval: Evaluate every this many train steps.
      train_checkpoint_interval, policy_checkpoint_interval: Save
        train/policy checkpoints every this many train steps.
      log_interval: Log loss and throughput every this many train steps.
      summary_interval: Record summaries every this many train steps.
      summaries_flush_secs: Summary-writer flush period.
      use_tf_functions: Wrap collection and training in `common.function`.
      debug_summaries, summarize_grads_and_vars: Forwarded to the agent.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None,
                lstm_size=lstm_size)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            # PPO is on-policy: train on everything collected this iteration.
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            # Step value read before this iteration's training update; all
            # interval checks below use it.
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            # Drop the consumed on-policy data before the next collection.
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            # Checkpointing is deliberately independent of `log_interval`:
            # nesting it under the logging branch would skip checkpoints
            # whenever the intervals are not multiples of each other.
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
示例#24
0
def main():
    """Train a DQN agent on `BombermanEnvironment` and export its policy.

    Fills a replay buffer with `INITIAL_COLLECT_STEPS` random-policy steps.
    It then runs 2000 collect/train iterations, evaluating every
    `EVAL_INTERVAL` train steps. Progress is saved under `checkpoint/`: an
    agent checkpoint plus pickled loss, metric and return histories. The
    greedy policy is finally saved to the `policy` directory.
    """

    def compute_avg_return(environment, policy, num_episodes=10):
        """Return the mean undiscounted episode return of `policy`.

        Each of the `num_episodes` episodes runs from `environment.reset()`
        until `is_last()`.
        """
        total_return = 0.0
        for _ in range(num_episodes):
            time_step = environment.reset()
            episode_return = 0.0

            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return

        avg_return = total_return / num_episodes
        # Rewards come from a batched TF environment; this assumes a batch of
        # one and returns its single element.
        return avg_return.numpy()[0]

    class ShowProgress:
        """Driver observer printing a running count of non-boundary steps."""

        def __init__(self, total):
            self.counter = 0
            self.total = total

        def __call__(self, trajectory):
            # Boundary trajectories do not advance the counter, so they must
            # not re-print the same (or an initial "0/N") progress line.
            if trajectory.is_boundary():
                return
            self.counter += 1
            if self.counter % 100 == 0:
                print("\r{}/{}".format(self.counter, self.total), end="")

    def train_agent(n_iterations, save_each=10000, print_each=500):
        """Alternate collection and training for `n_iterations` iterations.

        Args:
            n_iterations: Number of collect/train iterations to run.
            save_each: Checkpoint and pickle statistics whenever the agent's
                train-step counter (read before the update) is a multiple of
                this.
            print_each: Print loss and metrics every this many iterations.
        """
        time_step = None
        policy_state = agent.collect_policy.get_initial_state(
            tf_env.batch_size)
        iterator = iter(dataset)

        for iteration in range(n_iterations):
            # Counter value before this iteration's training update.
            step = agent.train_step_counter.numpy()

            time_step, policy_state = collect_driver.run(
                time_step, policy_state)
            trajectories, _ = next(iterator)

            train_loss = agent.train(trajectories)
            loss_value = train_loss.loss.numpy()
            all_train_loss.append(loss_value)

            current_metrics = [
                metric.result().numpy() for metric in train_metrics
            ]
            all_metrics.append(current_metrics)

            if iteration % print_each == 0:
                print("\nIteration: {}, loss:{:.2f}".format(
                    iteration, loss_value))

                for metric, value in zip(train_metrics, current_metrics):
                    print('{}: {}'.format(metric.name, value))

            if step % EVAL_INTERVAL == 0:
                avg_return = compute_avg_return(eval_tf_env, agent.policy,
                                                NUM_EVAL_EPISODES)
                print(f'Step = {step}, Average Return = {avg_return}')
                returns.append((step, avg_return))

            if step % save_each == 0:
                print("Saving model")
                # Saving the checkpoint first also creates `checkpoint/`,
                # which the pickle dumps below write into.
                train_checkpointer.save(train_step)
                policy_save_handler.save("policy")
                with open("checkpoint/train_loss.pickle", "wb") as f:
                    pickle.dump(all_train_loss, f)
                with open("checkpoint/all_metrics.pickle", "wb") as f:
                    pickle.dump(all_metrics, f)
                with open("checkpoint/returns.pickle", "wb") as f:
                    pickle.dump(returns, f)

    eval_tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())

    #tf_env = tf_py_environment.TFPyEnvironment(
    #   parallel_py_environment.ParallelPyEnvironment(
    #       [BombermanEnvironment] * N_PARALLEL_ENVIRONMENTS
    #   ))

    tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())

    q_net = QNetwork(tf_env.observation_spec(),
                     tf_env.action_spec(),
                     conv_layer_params=[(32, 3, 1), (32, 3, 1)],
                     fc_layer_params=[128, 64, 32])

    train_step = tf.Variable(0)
    update_period = 4
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # todo fine tune

    # Epsilon for epsilon-greedy collection decays from 0.7 to 0.01 over
    # 25000 environment steps (25000 // update_period train steps).
    epsilon_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=0.7,
        decay_steps=25000 // update_period,
        end_learning_rate=0.01)

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=0.99,
        train_step_counter=train_step,
        epsilon_greedy=lambda: epsilon_fn(train_step))

    agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=10000)
    replay_buffer_observer = replay_buffer.add_batch

    train_metrics = [
        tf_metrics.AverageReturnMetric(batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(batch_size=tf_env.batch_size)
    ]

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer_observer] + train_metrics,
        num_steps=update_period)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    initial_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[
            replay_buffer.add_batch,
            ShowProgress(INITIAL_COLLECT_STEPS)
        ],
        num_steps=INITIAL_COLLECT_STEPS)
    initial_driver.run()

    # Pairs of adjacent steps (num_steps=2) as needed for one-step TD targets.
    dataset = replay_buffer.as_dataset(sample_batch_size=64,
                                       num_steps=2,
                                       num_parallel_calls=3).prefetch(3)

    agent.train = common.function(agent.train)

    all_train_loss = []
    all_metrics = []
    returns = []

    checkpoint_dir = "checkpoint/"
    train_checkpointer = common.Checkpointer(ckpt_dir=checkpoint_dir,
                                             max_to_keep=1,
                                             agent=agent,
                                             policy=agent.policy,
                                             replay_buffer=replay_buffer,
                                             global_step=train_step)
    # train_checkpointer.initialize_or_restore()
    # train_step = tf.compat.v1.train.get_global_step()
    policy_save_handler = policy_saver.PolicySaver(agent.policy)

    # training here
    train_agent(2000)

    # save at end in every case

    policy_save_handler.save("policy")
示例#25
0
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC.

    Trains an agent built by `agent_class` (which must expose a
    `_safety_critic_network`, and with `online_critic` also `_safe_policy`
    and `train_sc`) on `env_name`. Training continues until the global step
    exceeds `num_global_steps` or `early_termination_fn()` returns True.
    Checkpoints for the agent, policy, safety critic and replay buffer(s)
    are written under `<root_dir>/train`.

    Args:
      root_dir: Base output directory; `~` is expanded.
      load_root_dir: If set, agent weights are loaded from
        `<load_root_dir>/train` via `misc.load_pi_ckpt`. The train checkpoint
        under `root_dir` is then not restored.
      env_load_fn: Callable `env_name -> environment`. Either a py or a
        `TFPyEnvironment` may be returned.
      env_name: Name passed to `env_load_fn`.
      agent_class: Agent constructor `(time_step_spec, action_spec, ...)`.
      initial_collect_driver_class, collect_driver_class: Driver
        constructors `(env, policy, observers=...)`.
      num_global_steps: Train until the global step exceeds this.
      train_steps_per_iteration: Gradient steps per collect call.
      train_metrics: Extra py metrics, wrapped in `TFPyMetric` and added
        to both train and eval metrics.
      train_sc_steps, train_sc_interval: With `online_critic`, run
        `train_sc_steps` safety-critic updates every `train_sc_interval`
        global steps.
      online_critic: Collect with the agent's safe policy, keep an online
        replay buffer, and train the safety critic from it.
      run_eval: Evaluate before training and every `eval_interval` steps.
      num_eval_episodes: Episodes per evaluation (also the buffer size of
        the average-return/length train metrics).
      eval_metrics_callback: Optional `callback(results, global_step)`.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Checkpoint frequencies in global steps.
      keep_rb_checkpoint: If False, replay-buffer checkpoints are removed
        at the end.
      log_interval, summary_interval: Logging/summary frequencies.
      summaries_flush_secs: Summary-writer flush period.
      early_termination_fn: Optional zero-arg predicate that stops training.

    Returns:
      The mean training loss of the last iteration. Returns None if the
      training loop did not run a single iteration.

    Raises:
      ValueError: If the loss was NaN, inf or above `MAX_LOSS` for more than
        `TERMINATE_AFTER_DIVERGED_LOSS_STEPS` consecutive iterations. The
        error is raised after checkpoint cleanup.
    """

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            # Same wrapping rule as the train env: avoid double-wrapping when
            # `env_load_fn` already returns a TFPyEnvironment.
            eval_tf_env = env_load_fn(env_name)
            if not isinstance(eval_tf_env, tf_py_environment.TFPyEnvironment):
                eval_tf_env = tf_py_environment.TFPyEnvironment(eval_tf_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            # Counts how often the safe policy had to resample unsafe actions.
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            # Small buffer of recent experience for safety-critic training;
            # cleared after every critic update.
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec, max_length=10000)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        # The first two entries are step metrics used by `tf_summaries` below.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(
            tf_env, collect_policy, observers=agent_observers + train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        if not rb_checkpointer.checkpoint_exists:
            # Only seed the buffer with random data when none was restored.
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(
                    tf_env,
                    initial_collect_policy,
                    observers=agent_observers + train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = osp.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        # Stays None if the loop below never runs (e.g. a restored global step
        # already past `num_global_steps`), so the final return is defined.
        total_loss = None
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = osp.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
示例#26
0
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    num_iterations=int(1e7),
    model_network_ctor_type='non-hierarchical',  # model net; only 'hierarchical' is implemented
    input_names=['camera', 'lidar'],  # names for inputs
    reconstruct_names=['roadmap'],  # names for masks
    pixor_names=['vh_clas', 'vh_regr', 'pixor_state'],  # names for pixor outputs
    reconstruct_pixor_state=True,  # whether to reconstruct pixor_state
    extra_names=['state'],  # extra inputs
    obs_size=64,  # size of observation image
    pixor_size=64,  # size of pixor output image
    perception_weight=1.0,  # weight of perception part loss
    # Params for collect
    initial_collect_steps=1000,
    replay_buffer_capacity=int(5e4+1),
    # Params for train
    training=True,  # whether to train, or just evaluate
    model_batch_size=32,  # model training batch size
    sequence_length=10,  # number of timesteps to train model
    model_learning_rate=1e-4,  # learning rate for model training
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=2000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=2000,
    log_interval=200,
    summary_interval=2000,
    summaries_flush_secs=10,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1):  # passed to load_carla_env; also scales the computed fps
  """Train and evaluate the model network of a PerceptionAgent in CARLA.

  Summaries and checkpoints are written to `root_dir/env_name/experiment_name`.
  One evaluation pass runs before training, then another every `eval_interval`
  global steps. When `training` is True, the replay buffer is filled by a
  single initial collection phase. That phase runs only when the global step
  is 0 and the buffer is empty. After it, `num_iterations` train steps run
  on sequences of `sequence_length + 1` steps sampled from the buffer. No
  further collection happens inside the training loop.

  Args:
    root_dir: Base output directory (`~` is expanded).
    experiment_name: Sub-directory name for this run.
    env_name: Environment name passed to `load_carla_env`. Also used in the
      output path.
    num_iterations: Number of train steps when `training` is True.
    model_network_ctor_type: Model network type. Only 'hierarchical' is
      implemented.
    input_names, reconstruct_names, pixor_names, extra_names: Observation
      channel names. Their de-duplicated union is requested from the env.
    obs_size, pixor_size: Image sizes forwarded to the env / model network.
    reconstruct_pixor_state, perception_weight: Forwarded to the model net.
    initial_collect_steps: Steps collected before training starts.
    replay_buffer_capacity: Max length of the (batch size 1) replay buffer.
    training: If False, only the initial evaluation is run.
    model_batch_size: Number of sequences sampled per train step.
    sequence_length: Sequence length used by the agent (the dataset samples
      `sequence_length + 1` steps).
    model_learning_rate: Adam learning rate for the model optimizer.
    gradient_clipping, summarize_grads_and_vars: Forwarded to the agent.
    num_eval_episodes: Episodes per evaluation. Also the buffer size of the
      average-return / episode-length metrics.
    eval_interval, train_checkpoint_interval, log_interval, summary_interval:
      Periods in global steps.
    num_images_per_summary: Episodes rendered per evaluation.
    summaries_flush_secs: Flush period of the summary writer.
    gpu_allow_growth: Enable memory growth on every visible GPU.
    gpu_memory_limit: If set, cap each GPU's virtual device at this limit.
    action_repeat: Forwarded to `load_carla_env`.

  Raises:
    NotImplementedError: If `model_network_ctor_type` is not 'hierarchical'.
      This includes the default value, 'non-hierarchical'.
  """
  # Setup GPU
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpu_allow_growth:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  if gpu_memory_limit:
    for gpu in gpus:
      tf.config.experimental.set_virtual_device_configuration(
          gpu,
          [tf.config.experimental.VirtualDeviceConfiguration(
              memory_limit=gpu_memory_limit)])

  # Get train and eval directories
  root_dir = os.path.expanduser(root_dir)
  root_dir = os.path.join(root_dir, env_name, experiment_name)

  # Get summary writers
  summary_writer = tf.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  # Eval metrics
  eval_metrics = [
      tf_metrics.AverageReturnMetric(
        name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(
        name='AverageEpisodeLengthEvalPolicy',
        buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Only record summaries on steps that are multiples of summary_interval.
  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # De-duplicate the requested channels while keeping a deterministic
    # order. A `set` would order them by string hash, which changes between
    # processes.
    obs_channels = list(dict.fromkeys(
        input_names + reconstruct_names + pixor_names + extra_names))

    # Create Carla environment
    py_env, eval_py_env = load_carla_env(
        env_name=env_name, lidar_bin=32/obs_size, pixor_size=pixor_size,
        obs_channels=obs_channels, action_repeat=action_repeat)

    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
    fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

    # Specs
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    # Get model network
    if model_network_ctor_type == 'hierarchical':
      model_network_ctor = sequential_latent_pixor_network.PixorSLMHierarchical
    else:
      raise NotImplementedError(
          'Unsupported model_network_ctor_type {!r}; only '
          "'hierarchical' is implemented.".format(model_network_ctor_type))
    model_net = model_network_ctor(
      input_names, reconstruct_names, obs_size=obs_size, pixor_size=pixor_size,
      reconstruct_pixor_state=reconstruct_pixor_state, perception_weight=perception_weight)

    # Build the perception agent
    actor_network = state_based_heuristic_actor_network.StateBasedHeuristicActorNetwork(
        observation_spec['state'],
        action_spec,
        desired_speed=9
        )

    tf_agent = perception_agent.PerceptionAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_network,
        model_network=model_net,
        model_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=model_learning_rate),
        num_images_per_summary=num_images_per_summary,
        sequence_length=sequence_length,
        gradient_clipping=gradient_clipping,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        fps=fps)
    tf_agent.initialize()

    # Train metrics (observed only during the initial collection below)
    env_steps = tf_metrics.EnvironmentSteps()
    average_return = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]

    # Get policies
    eval_policy = tf_agent.policy
    initial_collect_policy = tf_agent.collect_policy

    # Checkpointers
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=2)
    train_checkpointer.initialize_or_restore()

    model_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'model'),
        model=model_net,
        max_to_keep=2)

    def _evaluate():
      """Run `compute_summaries` for eval_policy on the eval environment."""
      # NOTE(review): fps is hard-coded to 10 here rather than using the
      # env-derived `fps` above. Confirm this is intentional.
      compute_summaries(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        train_step=global_step,
        summary_writer=summary_writer,
        num_episodes=num_eval_episodes,
        num_episodes_to_render=num_images_per_summary,
        model_net=model_net,
        fps=10,
        image_keys=['camera', 'lidar', 'roadmap'],
        pixor_size=pixor_size)

    # Evaluation before any training
    _evaluate()

    # Collect/restore data and train
    if training:
      # Get replay buffer
      replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,  # No parallel environments
          max_length=replay_buffer_capacity)
      replay_observer = [replay_buffer.add_batch]

      # Replay buffer checkpointer
      rb_checkpointer = common.Checkpointer(
          ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
          max_to_keep=1,
          replay_buffer=replay_buffer)
      rb_checkpointer.initialize_or_restore()

      # Collect driver
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=initial_collect_steps)

      # Optimize the performance by using tf functions
      initial_collect_driver.run = common.function(initial_collect_driver.run)

      # Collect initial replay data only for a fresh (non-restored) run.
      if (global_step.numpy() == 0 and replay_buffer.num_frames() == 0):
        logging.info(
            'Collecting experience for %d steps '
            'with a model-based policy.', initial_collect_steps)
        initial_collect_driver.run()
        rb_checkpointer.save(global_step=global_step.numpy())

      # Dataset generates trajectories with shape [Bxslx...]
      dataset = replay_buffer.as_dataset(
          num_parallel_calls=3,
          sample_batch_size=model_batch_size,
          num_steps=sequence_length + 1).prefetch(3)
      iterator = iter(dataset)

      # Get train model step
      def train_step():
        experience, _ = next(iterator)
        return tf_agent.train(experience)
      train_step = common.function(train_step)

      # Start training
      for _ in range(num_iterations):

        loss = train_step()

        # Log training information
        if global_step.numpy() % log_interval == 0:
          logging.info('global steps = %d, model loss = %f', global_step.numpy(), loss.loss)

        # Get training metrics
        for train_metric in train_metrics:
          train_metric.tf_summaries(train_step=global_step.numpy())

        # Evaluation
        if global_step.numpy() % eval_interval == 0:
          _evaluate()

        # Save checkpoints
        global_step_val = global_step.numpy()
        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)
          model_checkpointer.save(global_step=global_step_val)
示例#27
0
    batch_size=tf_env.batch_size,
    # This can store 4 million trajectories (note: requires a lot of RAM)
    max_length=n_iterations)

# Observer callable: every trajectory it receives is appended to the replay
# buffer.
replay_buffer_observer = replay_buffer.add_batch

## ------------------------------------------------------------------------------
## ------------------------------------------------------------------------------
## ------------------------------------------------------------------------------

# Training metrics (all with default constructor arguments), instantiated in
# this fixed order: episode count, env-step count, average return, average
# episode length.
train_metrics = [
    metric_cls()
    for metric_cls in (
        tf_metrics.NumberOfEpisodes,
        tf_metrics.EnvironmentSteps,
        tf_metrics.AverageReturnMetric,
        tf_metrics.AverageEpisodeLengthMetric,
    )
]

logging.getLogger().setLevel(logging.INFO)

## ------------------------------------------------------------------------------
## ------------------------------------------------------------------------------
## ------------------------------------------------------------------------------

# Driver stepping `tf_env` with the agent's collect policy, num_steps=1 per
# run. Each trajectory goes to the replay-buffer observer first, then to
# every train metric.
collect_driver = DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=[replay_buffer_observer, *train_metrics],
    num_steps=1,
)
# Speed up as tensorflow function
示例#28
0
def run():
    """Train a DQN agent on SnakeEnv with checkpointing and periodic eval.

    Hyperparameters come from module-level names defined elsewhere in this
    file, e.g. ``learning_rate``, ``replay_buffer_max_length``,
    ``collect_steps_per_iteration``, ``initial_collect_steps``,
    ``num_eval_episodes``, ``batch_size``, ``num_iterations``,
    ``train_steps_per_iteration`` and the ``*_interval`` values. Summaries,
    checkpoints and captured videos are written under ``/tf-logs/snake``.

    A restored replay buffer holding at least ``initial_collect_steps``
    frames skips the random seeding phase. Average eval returns are recorded
    in a local list as ``(global_step, average_return)`` pairs. The function
    itself returns None.
    """
    tf_env = tf_py_environment.TFPyEnvironment(SnakeEnv())
    eval_env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=50))

    q_net = q_network.QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        conv_layer_params=(),
        fc_layer_params=(512, 256, 128),
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    global_counter = tf.compat.v1.train.get_or_create_global_step()

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_counter,
        gamma=0.95,
        epsilon_greedy=0.1,
        n_step_update=1,
    )

    root_dir = os.path.join('/tf-logs', 'snake')
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_max_length,
    )

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=agent,
        global_step=global_counter,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_counter,
    )

    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver.run = common.function(collect_driver.run)
    agent.train = common.function(agent.train)

    random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                                    tf_env.action_spec())

    # Seed the buffer with random experience unless enough was restored.
    if replay_buffer.num_frames() >= initial_collect_steps:
        logging.info("We loaded memories, not doing random seed")
    else:
        logging.info("Capturing %d steps to seed with random memories",
                     initial_collect_steps)

        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            random_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

    train_summary_writer = tf.summary.create_file_writer(train_dir)
    train_summary_writer.set_as_default()
    # Create the eval writer once and reuse it. Building a new writer per
    # evaluation opens an additional event file / writer every time.
    eval_summary_writer = tf.summary.create_file_writer(eval_dir)

    avg_returns = []
    avg_return_metric = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes)
    eval_metrics = [
        avg_return_metric,
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    def evaluate():
        """Evaluate agent.policy on eval_env, log metrics, record avg return."""
        metric_utils.eager_compute(
            eval_metrics,
            eval_env,
            agent.policy,
            num_episodes=num_eval_episodes,
            train_step=global_counter,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)
        avg_returns.append(
            (global_counter.numpy(), avg_return_metric.result().numpy()))

    logging.info("Running initial evaluation")
    evaluate()

    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_counter.numpy()
    time_acc = 0

    # Sample pairs of consecutive steps (num_steps=2) for 1-step TD updates.
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = iter(dataset)

    @common.function
    def train_step():
        experience, _ = next(iterator)
        return agent.train(experience)

    for _ in range(num_iterations):
        start_time = time.time()
        # Thread time_step / policy_state through so collection continues the
        # current episode instead of restarting each iteration.
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )

        for _ in range(train_steps_per_iteration):
            train_loss = train_step()
        time_acc += time.time() - start_time

        step = global_counter.numpy()

        if step % log_interval == 0:
            logging.info("step = %d, loss = %f", step, train_loss.loss)
            steps_per_sec = (step - timed_at_step) / time_acc
            logging.info("%.3f steps/sec", steps_per_sec)
            timed_at_step = step
            time_acc = 0

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_counter,
                                      step_metrics=train_metrics[:2])

        if step % train_checkpoint_interval == 0:
            train_checkpointer.save(global_step=step)

        if step % policy_checkpoint_interval == 0:
            policy_checkpointer.save(global_step=step)

        if step % rb_checkpoint_interval == 0:
            rb_checkpointer.save(global_step=step)

        if step % capture_interval == 0:
            print("Capturing run:")
            capture_run(os.path.join(root_dir, "snake" + str(step) + ".mp4"),
                        eval_env, agent.policy)

        if step % eval_interval == 0:
            print("EVALUATION TIME:")
            evaluate()
示例#29
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v1',
    env_load_fn=suite_mujoco.load,
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG (TF1 graph mode).

  Builds the environment(s), DDPG agent, replay buffer, collect/train ops,
  summaries and checkpointers as a TF1 graph. It then runs
  `num_iterations` collect+train iterations inside a `tf.Session`, with
  periodic logging, checkpointing and evaluation.

  Args:
    root_dir: Output directory (`~` is expanded). Summaries and checkpoints
      go to its `train/` and `eval/` subdirectories.
    env_name: Environment name passed to `env_load_fn`.
    env_load_fn: Callable that loads a Python environment by name.
    num_iterations: Number of collect/train iterations.
    actor_fc_layers, critic_obs_fc_layers, critic_action_fc_layers,
      critic_joint_fc_layers: Layer sizes for the actor / critic networks.
    initial_collect_steps: Steps collected (into the replay buffer only)
      before training.
    collect_steps_per_iteration: Env steps collected per iteration.
    num_parallel_environments: If > 1, training uses a
      `ParallelPyEnvironment` with this many copies.
    replay_buffer_capacity: Max length of the replay buffer.
    ou_stddev, ou_damping, target_update_tau, target_update_period,
      dqda_clipping, td_errors_loss_fn, gamma, reward_scale_factor,
      gradient_clipping, debug_summaries, summarize_grads_and_vars:
      Forwarded to `DdpgAgent`.
    train_steps_per_iteration: Train ops run per iteration. Must be >= 1.
    batch_size: Sample batch size drawn from the replay buffer.
    actor_learning_rate, critic_learning_rate: Adam learning rates.
    num_eval_episodes: Episodes per evaluation. Also the eval metrics'
      buffer size.
    eval_interval, train_checkpoint_interval, policy_checkpoint_interval,
      rb_checkpoint_interval, log_interval, summary_interval: Periods in
      global steps.
    summaries_flush_secs: Flush period of the summary writers.
    eval_metrics_callback: Optional callback forwarded to
      `metric_utils.compute_summaries` on every evaluation.

  Raises:
    ValueError: If `train_steps_per_iteration` is less than 1.
  """
  if train_steps_per_iteration < 1:
    # The logging branch below reads the loss returned by the last train
    # call. With zero train calls per iteration it would be unbound.
    raise ValueError(
        'train_steps_per_iteration must be >= 1, got {}'.format(
            train_steps_per_iteration))

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    if num_parallel_environments > 1:
      tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * num_parallel_environments))
    else:
      tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_py_env = env_load_fn(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec(),
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    global_step = tf.train.get_or_create_global_step()

    collect_policy = tf_agent.collect_policy()
    # The initial collection only feeds the replay buffer. Train metrics are
    # observed by the regular collect op below.
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = dataset.make_initializable_iterator()
    trajectories, unused_info = iterator.get_next()
    train_op = tf_agent.train(
        experience=trajectories, train_step_counter=global_step)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=tf.contrib.checkpoint.List(train_metrics))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy(),
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])
    # Gather summary ops only after the metric summaries above are defined.
    summary_op = tf.contrib.summary.all_summary_ops()

    with eval_summary_writer.as_default(), \
         tf.contrib.summary.always_record_summaries():
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      tf.contrib.summary.initialize(session=sess)
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_op, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          tf.logging.info('step = %d, loss = %f', global_step_val,
                          loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          # Pass args lazily instead of pre-formatting the message with `%`.
          tf.logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(
    ),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40, ),  # lstm size for actor
    critic_lstm_size=(40, ),  # lstm size for critic
    actor_output_fc_layers=(100, ),  # lstm output
    critic_output_fc_layers=(100, ),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration paprameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration paramter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1
):  # Name of single observation channel, ['camera', 'lidar', 'birdeye']
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Whether to record for summary
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names +
                                                 mask_names,
                                                 action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 discrete=True,
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        ## Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
            else:
                raise NotImplementedError
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec((latent_size, ),
                                                             dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)

        else:
            # Set up preprosessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers,
                    continuous_projection_net=normal_projection_net)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    td_errors_loss_fn=tf.math.
                    squared_difference,  # make critic loss dimension compatible
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step,
                                                  max_to_keep=2)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize the performance by using tf functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps'
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        if agent_name == 'latent_sac':
            compute_summaries(eval_metrics,
                              eval_tf_env,
                              eval_policy,
                              train_step=global_step,
                              summary_writer=summary_writer,
                              num_episodes=1,
                              num_episodes_to_render=1,
                              model_net=model_net,
                              fps=10,
                              image_keys=input_names + mask_names)
        else:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=1,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
            )
            metric_utils.log_metrics(eval_metrics)

        # Dataset generates trajectories with shape [Bxslx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

            if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
                train_model_step()
            else:
                # Run collect
                time_step, _ = collect_driver.run(time_step=time_step)

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Get training metrics
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                # Log evaluation metrics
                if agent_name == 'latent_sac':
                    compute_summaries(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        train_step=global_step,
                        summary_writer=summary_writer,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        model_net=model_net,
                        fps=10,
                        image_keys=input_names + mask_names)
                else:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    metric_utils.log_metrics(eval_metrics)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)