Code example #1
File: train_eval.py  Project: fxia22/agents
def train_eval(
        root_dir,
        gpu=0,
        env_load_fn=None,
        model_ids=None,
        eval_env_mode='headless',
        num_iterations=1000000,
        conv_layer_params=None,
        encoder_fc_layers=[256],
        actor_fc_layers=[400, 300],
        critic_obs_fc_layers=[400],
        critic_action_fc_layers=None,
        critic_joint_fc_layers=[300],
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        eval_only=False,
        eval_deterministic=False,
        num_parallel_environments_eval=1,
        model_ids_eval=None,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=10000,
        rb_checkpoint_interval=50000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if model_ids is None:
            model_ids = [None] * num_parallel_environments
        else:
            assert len(model_ids) == num_parallel_environments, \
                'model ids provided, but length not equal to num_parallel_environments'

        if model_ids_eval is None:
            model_ids_eval = [None] * num_parallel_environments_eval
        else:
            assert len(model_ids_eval) == num_parallel_environments_eval,\
                'model ids eval provided, but length not equal to num_parallel_environments_eval'

        tf_py_env = [
            lambda model_id=model_ids[i]: env_load_fn(model_id, 'headless', gpu)
            for i in range(num_parallel_environments)
        ]
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(tf_py_env))

        if eval_env_mode == 'gui':
            assert num_parallel_environments_eval == 1, 'only one GUI env is allowed'
        eval_py_env = [
            lambda model_id=model_ids_eval[i]: env_load_fn(
                model_id, eval_env_mode, gpu)
            for i in range(num_parallel_environments_eval)
        ]
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            eval_py_env)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        print('observation_spec', observation_spec)
        print('action_spec', action_spec)

        glorot_uniform_initializer = tf.compat.v1.keras.initializers.glorot_uniform()
        preprocessing_layers = {
            'depth_seg':
            tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=conv_layer_params,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
            'sensor':
            tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=None,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        actor_net = actor_network.ActorNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            fc_layer_params=actor_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        if eval_deterministic:
            eval_py_policy = py_tf_policy.PyTFPolicy(
                greedy_policy.GreedyPolicy(tf_agent.policy))
        else:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                buffer_size=100, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=100, batch_size=num_parallel_environments),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps * num_parallel_environments).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration *
            num_parallel_environments).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        # Dataset generates trajectories with shape [Bx2x...]
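        # Sampling 5x batch_size and filtering episode-boundary pairs before
        # re-batching keeps the delivered batch size close to batch_size even
        # after invalid transitions are dropped.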
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=5,
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(
                True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        init_agent_op = tf_agent.initialize()
        with sess.as_default():
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)

            if eval_only:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=0,
                    callback=eval_metrics_callback,
                    tf_summaries=False,
                    log=True,
                )
                episodes = eval_py_env.get_stored_episodes()
                episodes = [
                    episode for sublist in episodes for episode in sublist
                ][:num_eval_episodes]
                metrics = episode_utils.get_metrics(episodes)
                for key in sorted(metrics.keys()):
                    print(key, ':', metrics[key])

                save_path = os.path.join(eval_dir, 'episodes_vis.pkl')
                episode_utils.save(episodes, save_path)
                print('EVAL DONE')
                return

            # Initialize training.
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)
            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=0,
                    callback=eval_metrics_callback,
                    tf_summaries=True,
                    log=True,
                )
                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # print('collect:', time.time() - start_time)

                # train_start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                # print('train:', time.time() - train_start_time)

                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=0,
                        callback=eval_metrics_callback,
                        tf_summaries=True,
                        log=True,
                    )
                    with eval_summary_writer.as_default(), \
                            tf.compat.v2.summary.record_if(True):
                        with tf.name_scope('Metrics/'):
                            episodes = eval_py_env.get_stored_episodes()
                            episodes = [
                                episode for sublist in episodes
                                for episode in sublist
                            ][:num_eval_episodes]
                            metrics = episode_utils.get_metrics(episodes)
                            for key in sorted(metrics.keys()):
                                print(key, ':', metrics[key])
                                metric_op = tf.compat.v2.summary.scalar(
                                    name=key,
                                    data=metrics[key],
                                    step=global_step_val)
                                sess.run(metric_op)
                    sess.run(eval_summary_flush_op)

        sess.close()
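A minimal sketch of how the train_eval() above might be invoked; my_env_load_fn and the flag values are hypothetical stand-ins, not part of the original snippet:

def main():
    # my_env_load_fn is a hypothetical loader returning a PyEnvironment
    # for a given (model_id, mode, gpu), matching the call signature above.
    train_eval(
        root_dir='~/tmp/ddpg/gibson',
        env_load_fn=my_env_load_fn,
        model_ids=None,
        num_parallel_environments=1,
        num_iterations=1000000)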
Code example #2
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # TODO(kbanoop): Figure out if it is possible to avoid the with block.
    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        collect_policy = tf_agent.collect_policy()

        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_episodes=initial_collect_steps).run()

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories,
                                  train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    tf.logging.info('%.3f steps/sec' % steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
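This example targets TF 1.x and the long-removed tf.contrib.summary module. A sketch of the TF 2.x equivalent of its writer setup (not part of the original snippet):

import tensorflow as tf

# TF 2.x counterpart of tf.contrib.summary.create_file_writer as used above.
train_summary_writer = tf.summary.create_file_writer(
    train_dir, flush_millis=summaries_flush_secs * 1000)
with train_summary_writer.as_default():
    tf.summary.scalar('loss', 0.0, step=0)  # illustrative scalar write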
Code example #3
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_allowlist='position',
    num_iterations=100000,
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    # Params for collect
    initial_collect_episodes=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=200,
    batch_size=64,
    train_sequence_length=10,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=None,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=True,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_allowlist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_allowlist=[observations_allowlist])
      ]
    else:
      env_wrappers = []

    tf_env = tf_py_environment.TFPyEnvironment(
        suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers))

    actor_net = actor_rnn_network.ActorRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers)

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_rnn_network.CriticRnnNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_episodes)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d episodes '
        'with the agent collect policy.', initial_collect_episodes)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

    return train_loss
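A hypothetical entry point for the eager train_eval() above; the flag values mirror the function defaults and are assumptions:

from absl import app
import tensorflow as tf

def main(_):
    tf.compat.v1.enable_v2_behavior()
    train_eval(root_dir='/tmp/ddpg_rnn/cartpole', num_iterations=100000)

if __name__ == '__main__':
    app.run(main)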
Code example #4
# (Snippet truncated above: these trailing arguments appear to belong to an
# exploration-decay schedule such as tf.compat.v1.train.polynomial_decay.)
#                                                     decay_steps,
#                                                     end_epsilon,
#                                                     power=1.0,
#                                                     cycle=False)
train_step = train_utils.create_train_step()

agent = ddpg_agent.DdpgAgent(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=actor_optimizer,
    critic_optimizer=critic_optimizer,
    ou_stddev=.2,
    ou_damping=0.05,
    target_critic_network=None,
    target_update_tau=0.001,
    dqda_clipping=None,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=0.99,
    reward_scale_factor=1.0,  # expects a float; None would fail at train time
    gradient_clipping=1.0,  # expects a float max-norm, not a bool
    debug_summaries=True,
    summarize_grads_and_vars=True,
    train_step_counter=train_step,  # created above via train_utils.create_train_step()
    name="Agent")

agent.initialize()

####REPLAY BUFFER####

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
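The constructor above is cut off mid-call; a plausible completion, based on the TFUniformReplayBuffer signature (the batch-size source and capacity below are assumptions), is:

    data_spec=agent.collect_data_spec,  # what the agent's collect policy emits
    batch_size=tf_env.batch_size,       # hypothetical: the training env's batch size
    max_length=100000)                  # hypothetical capacity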
Code example #5
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    eval_env_name=None,
    env_load_fn=suite_mujoco.load,
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DDPG."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if num_parallel_environments > 1:
      tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * num_parallel_environments))
    else:
      tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_env_name = eval_env_name or env_name
    eval_py_env = env_load_fn(eval_env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = iterator.get_next()
    train_fn = common.function(tf_agent.train)
    train_op = train_fn(experience=trajectories)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(b/126239733) Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
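A hypothetical entry point for this train_eval(); the defaults already target HalfCheetah-v2, so only root_dir is strictly required:

from absl import app
import tensorflow as tf

def main(_):
    tf.compat.v1.enable_resource_variables()
    train_eval('/tmp/ddpg/halfcheetah')

if __name__ == '__main__':
    app.run(main)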
Code example #6
    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=[128],
        action_fc_layer_params=[128],
        joint_fc_layer_params=[128],
    )
    agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=None,
        td_errors_loss_fn=keras.losses.Huber(reduction="none"),
        gamma=0.995,
        reward_scale_factor=1,
        gradient_clipping=None,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        train_step_counter=train_step)
    agent.initialize()

with tf.device('/cpu:0'):
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
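This constructor is also cut off; based on its signature, the remaining argument is typically max_length (the capacity value is an assumption):

        max_length=100000)  # hypothetical capacity; the original is truncated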
Code example #7
# first, create three independent single-intersection agents
# then, create one coordinate agent
agent_system = []
replay_buffers = []
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=LEARNING_RATE)
for i in range(3):
    actor_network = ActorNetwork(obs_spec4independent_agent,
                                 act_spec4independent_agent)
    critic_network = CriticNetwork(obs_spec4independent_agent,
                                   act_spec4independent_agent,
                                   q_spec4independent_agent)
    agent = ddpg_agent.DdpgAgent(ts_spec4independent_agent,
                                 act_spec4independent_agent,
                                 actor_network,
                                 critic_network,
                                 actor_optimizer=optimizer,
                                 critic_optimizer=optimizer,
                                 gamma=1,
                                 ou_stddev=5.0,
                                 ou_damping=0.9)
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=REPLAY_BUFFER_MAX_LENGTH)

    agent_system.append(agent)
    replay_buffers.append(replay_buffer)
# create coordinate agent
actor_network = ActorNetwork(obs_spec4coordinate_agent,
                             act_spec4coordinate_agent)
critic_network = CriticNetwork(obs_spec4coordinate_agent,
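The call is cut off here; a plausible continuation, by analogy with the per-intersection networks built in the loop above (q_spec4coordinate_agent and the agent construction are assumptions), is:

                               act_spec4coordinate_agent,
                               q_spec4coordinate_agent)
coordinate_agent = ddpg_agent.DdpgAgent(ts_spec4coordinate_agent,
                                        act_spec4coordinate_agent,
                                        actor_network,
                                        critic_network,
                                        actor_optimizer=optimizer,
                                        critic_optimizer=optimizer,
                                        gamma=1,
                                        ou_stddev=5.0,
                                        ou_damping=0.9)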
Code example #8
critic_net = critic_network.CriticNetwork(
    (train_env.observation_spec(), train_env.action_spec()),
    observation_fc_layer_params=None,
    action_fc_layer_params=None,
    joint_fc_layer_params=critic_joint_fc_layer_params)

global_step = tf.compat.v2.Variable(0)

tf_agent = ddpg_agent.DdpgAgent(train_env.time_step_spec(),
                                train_env.action_spec(),
                                actor_network=actor_net,
                                critic_network=critic_net,
                                actor_optimizer=actor_optimizer,
                                critic_optimizer=critic_optimizer,
                                target_update_tau=target_update_tau,
                                target_update_period=target_update_period,
                                td_errors_loss_fn=loss_fn,
                                gamma=gamma,
                                reward_scale_factor=reward_scale_factor,
                                gradient_clipping=gradient_clipping,
                                train_step_counter=global_step)

ddpg_step_list, ddpg_reward_list, ddpg_name, ddpg_loss_step, ddpg_loss_list, ddpg_train_returns = train_and_evaluate_ACagent(
    tf_agent,
    train_env,
    eval_env,
    num_iterations=num_iterations,
    batch_size=batch_size,
    name='ddpg')
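train_and_evaluate_ACagent is a project-local helper not shown here; a sketch of the standard TF-Agents collect/train loop such a helper typically wraps (buffer capacity and collect settings are assumptions):

from tf_agents.drivers import dynamic_step_driver
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    tf_agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=100000)  # hypothetical capacity
collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env, tf_agent.collect_policy,
    observers=[replay_buffer.add_batch], num_steps=1)
dataset = replay_buffer.as_dataset(
    sample_batch_size=batch_size, num_steps=2).prefetch(3)
iterator = iter(dataset)
for _ in range(num_iterations):
    collect_driver.run()
    experience, _ = next(iterator)
    train_loss = tf_agent.train(experience).loss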
Code example #9
eval_py_env = CTMEnv()

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)


actor_network = ActorNetwork(train_env.observation_spec(), train_env.action_spec())
critic_network = CriticNetwork(train_env.observation_spec(), train_env.action_spec(), q_value_spec)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

agent = ddpg_agent.DdpgAgent(train_env.time_step_spec(),
                             train_env.action_spec(),
                             actor_network,
                             critic_network,
                             actor_optimizer=optimizer,
                             critic_optimizer=optimizer,
                             gamma=0.99,
                             ou_stddev=5.0,
                             ou_damping=0.9)

agent.initialize()



eval_policy = agent.policy
collect_policy = agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())
expert_policy = ExpertPolicy()
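A minimal sketch of scoring eval_policy on eval_env by average return, mirroring the standard TF-Agents tutorial helper (the episode count is an assumption):

def compute_avg_return(environment, policy, num_episodes=10):
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward.numpy()[0]
        total_return += episode_return
    return total_return / num_episodes

avg_return = compute_avg_return(eval_env, eval_policy, num_episodes=10)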
Code example #10
actor_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-4)
critic_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
train_step = train_utils.create_train_step()

tf_agent = ddpg_agent.DdpgAgent(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=actor_optimizer,
    critic_optimizer=critic_optimizer,
    ou_stddev=0.15,
    ou_damping=0.2,
    target_actor_network=target_actor_net,
    target_critic_network=target_critic_net,
    target_update_tau=0.001,
    target_update_period=1,
    dqda_clipping=None,
    td_errors_loss_fn=None,
    gamma=0.99,
    reward_scale_factor=1.0,  # expects a float; None would fail at train time
    gradient_clipping=None,
    debug_summaries=None,
    summarize_grads_and_vars=False,
    train_step_counter=train_step,
    name="Agent"
)


#server = reverb.Server(tables=[
#    reverb.Table(
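A hypothetical completion of the commented-out Reverb setup, following the usual tf_agents + reverb pattern (table name, capacity, and rate limiter are assumptions):

# import reverb
# table_name = 'uniform_table'
# server = reverb.Server(tables=[
#     reverb.Table(
#         name=table_name,
#         sampler=reverb.selectors.Uniform(),
#         remover=reverb.selectors.Fifo(),
#         max_size=100000,
#         rate_limiter=reverb.rate_limiters.MinSize(1))
# ])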
Code example #11
def train_eval():
    # ==========================================================================
    # Setup Logging
    # ==========================================================================
    log_dir = get_log_dir(TrainingParameters.ENV, TrainingParameters.ALGO)
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        log_dir + '/train',
        flush_millis=TrainingParameters.LOG_FLUSH_STEP * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        log_dir + '/eval',
        flush_millis=EvaluationParameters.LOG_FLUSH_STEP * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            buffer_size=EvaluationParameters.NUM_EVAL_EPISODE),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=EvaluationParameters.NUM_EVAL_EPISODE)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
            global_step % TrainingParameters.SUMMARY_INTERVAL, 0)):
        # ======================================================================
        # Create Parallel Environment
        # ======================================================================
        if TrainingParameters.NUM_AGENTS > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: TrainingParameters.ENV_LOAD_FN(TrainingParameters.ENV)]
                    * TrainingParameters.NUM_AGENTS))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(
                TrainingParameters.ENV_LOAD_FN(TrainingParameters.ENV))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            TrainingParameters.ENV_LOAD_FN(TrainingParameters.ENV))

        # ======================================================================
        # Create Actor Network
        # ======================================================================
        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=TrainingParameters.ACTOR_LAYERS,
            activation_fn=TrainingParameters.ACTOR_ACTIVATION)

        # ======================================================================
        # Create Critic Network and Agent
        # ======================================================================

        critic_input_tensor_spec = (tf_env.time_step_spec().observation,
                                    tf_env.action_spec())
        if TrainingParameters.ALGO == 'D4PG':
            critic_net = distributional_critic_network.DistributionalCriticNetwork(
                critic_input_tensor_spec,
                num_atoms=TrainingParameters.NUM_ATOMS,
                observation_fc_layer_params=TrainingParameters.CRITIC_OBSERVATION_LAYERS,
                action_fc_layer_params=TrainingParameters.CRITIC_ACTION_LAYERS,
                joint_fc_layer_params=TrainingParameters.CRITIC_JOINT_LAYERS,
                activation_fn=TrainingParameters.CRITIC_ACTIVATION)
            tf_agent = d4pg_agent.D4pgAgent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_network=actor_net,
                critic_network=critic_net,
                min_v=TrainingParameters.V_MIN,
                max_v=TrainingParameters.V_MAX,
                n_step_return=TrainingParameters.N_STEP_RETURN,
                actor_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.ACTOR_LEARNING_RATE),
                actor_l2_lambda=TrainingParameters.ACTOR_L2_LAMBDA,
                critic_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.CRITIC_LEARNING_RATE),
                critic_l2_lambda=TrainingParameters.CRITIC_L2_LAMBDA,
                ou_stddev=TrainingParameters.OU_STDDEV,
                ou_damping=TrainingParameters.OU_DAMPING,
                target_update_tau=TrainingParameters.TAU,
                target_update_period=TrainingParameters.TARGET_UPDATE_PERIOD,
                dqda_clipping=None,
                td_errors_loss_fn=TrainingParameters.TD_ERROR_LOSS_FN,
                gamma=TrainingParameters.DISCOUNT_RATE,
                reward_scale_factor=TrainingParameters.REWARD_SCALE_FACTOR,
                gradient_clipping=None,
                debug_summaries=True,
                summarize_grads_and_vars=True,
                train_step_counter=global_step)

        elif TrainingParameters.ALGO == 'DDPG':
            critic_net = critic_network.CriticNetwork(
                critic_input_tensor_spec,
                observation_fc_layer_params=TrainingParameters.CRITIC_OBSERVATION_LAYERS,
                action_fc_layer_params=TrainingParameters.CRITIC_ACTION_LAYERS,
                joint_fc_layer_params=TrainingParameters.CRITIC_JOINT_LAYERS,
            )
            tf_agent = ddpg_agent.DdpgAgent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.ACTOR_LEARNING_RATE),
                critic_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=TrainingParameters.CRITIC_LEARNING_RATE),
                train_step_counter=global_step,
                gamma=TrainingParameters.DISCOUNT_RATE,
                td_errors_loss_fn=TrainingParameters.TD_ERROR_LOSS_FN,
                target_update_tau=TrainingParameters.TAU,
                target_update_period=TrainingParameters.TARGET_UPDATE_PERIOD)
        else:
            raise ValueError(
                "Unsupported ALGO in params.py: expected 'D4PG' or 'DDPG'.")

        tf_agent.initialize()

        # ======================================================================
        # Set Up Replay Buffer, Trajectory Collector, and Training Op
        # ======================================================================
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(
                batch_size=TrainingParameters.NUM_AGENTS),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=TrainingParameters.NUM_AGENTS),
        ]

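        # The collect policy wraps the actor with Ornstein-Uhlenbeck
        # exploration noise (configured via ou_stddev / ou_damping above).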
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=TrainingParameters.REPLAY_MEM_SIZE)

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=TrainingParameters.INITIAL_REPLAY_MEM_SIZE)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=TrainingParameters.NUM_DATA_COLLECT)

        # Dataset generates trajectories with shape [B x (N_STEP_RETURN+1) x ...]
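        # (N_STEP_RETURN + 1 consecutive steps give each sample a full N-step
        # transition plus the bootstrap state for the N-step TD target.)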
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=TrainingParameters.BATCH_SIZE,
            num_steps=TrainingParameters.N_STEP_RETURN + 1).prefetch(
                tf.data.experimental.AUTOTUNE)

        dataset_iterator = iter(dataset)

        def train_step():
            experience, _ = next(dataset_iterator)
            return tf_agent.train(experience)

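        # Wrapping the drivers and the train step in common.function compiles
        # them into TF graphs, avoiding eager-mode overhead in the hot loop.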
        if TrainingParameters.USE_TF_FUNCTION:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)
            train_step = common.function(train_step)

        # ======================================================================
        # Collecting Data
        # ======================================================================
        initial_collect_driver.run()

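        # Experience collection runs on a background thread so environment
        # stepping overlaps with gradient updates; run_background (not shown
        # here) is assumed to loop collect_driver.run until the event is set.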
        stop_threading_event = threading.Event()
        threads = []
        threads.append(
            threading.Thread(target=run_background,
                             args=(collect_driver.run, stop_threading_event)))

        for thread in threads:
            thread.start()

        # ======================================================================
        # Initial Policy Evaluation
        # ======================================================================
        eval_policy = tf_agent.policy
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=EvaluationParameters.NUM_EVAL_EPISODE,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        # ======================================================================
        # Train the Agent
        # ======================================================================
        timed_at_step = global_step.numpy()
        time_acc = 0
        for _ in range(TrainingParameters.NUM_ITERATION):
            start_time = time.time()
            for _ in range(TrainingParameters.NUM_STEP_PER_ITERATION):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % TrainingParameters.LOG_INTERVAL == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % EvaluationParameters.EVAL_INTERVAL == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=EvaluationParameters.NUM_EVAL_EPISODE,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                metric_utils.log_metrics(eval_metrics)
                policy_saver.PolicySaver(eval_policy).save(
                    log_dir + '/eval/policy_%d' % global_step.numpy())

        stop_threading_event.set()
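
The loop above delegates collection to a run_background helper that the snippet references but never defines. A minimal sketch, assuming the helper simply re-runs the driver until the main thread sets the stop event (the original project's version may differ):

def run_background(run_fn, stop_event):
    # Hypothetical helper: keep stepping the collect driver until the main
    # thread signals shutdown via stop_event (a threading.Event).
    while not stop_event.is_set():
        run_fn()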
Code Example #12 (fragment: names such as critic_net_input_specs, actor_net, and train_env are defined earlier in the snippet's source file)
critic_net = critic_network.CriticNetwork(
    critic_net_input_specs,
    observation_fc_layer_params=critic_obs_layer_params,
    action_fc_layer_params=None,
    joint_fc_layer_params=critic_fc_layer_params,
)

global_step = tf.compat.v1.train.get_or_create_global_step()
tf_agent = ddpg_agent.DdpgAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=actor_learning_rate),
    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
        learning_rate=critic_learning_rate),
    train_step_counter=global_step,
    gamma=0.995,
    td_errors_loss_fn=tf.keras.losses.MSE,
    target_update_tau=target_update_tau,
    target_update_period=target_update_period)
tf_agent.initialize()

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy


def compute_avg_return(environment, policy, num_episodes=5):
    # Standard TF-Agents evaluation loop; the body below is completed from
    # the usual tutorial pattern because the original snippet is cut off
    # after `total_return = 0.0`.
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]
Code Example #13 (fragment: the snippet begins partway through a CriticNetwork construction; the opening of the call is reconstructed below)
    critic_net = critic_network.CriticNetwork(
        # Opening reconstructed: the original snippet begins mid-argument-list;
        # the input spec below follows the usual TF-Agents pattern.
        (train_env.time_step_spec().observation, train_env.action_spec()),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()
    
    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                    train_env.action_spec())
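
The random policy created above is the standard way to seed the replay buffer with exploratory experience before DDPG training begins. A minimal sketch of that initial collection, assuming a replay_buffer and an initial_collect_steps value like those in the other snippets (both names are illustrative here):

    # Fill the replay buffer with random transitions before learning starts.
    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        train_env,
        random_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps)
    initial_collect_driver.run()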