Example #1
    def __init__(self,
                 root_dir,
                 train_step,
                 agent,
                 experience_dataset_fn=None,
                 after_train_strategy_step_fn=None,
                 triggers=None,
                 checkpoint_interval=100000,
                 summary_interval=1000,
                 max_checkpoints_to_keep=3,
                 use_kwargs_in_agent_train=False,
                 strategy=None):
        """Initializes a Learner instance.

    Args:
      root_dir: Main directory path where checkpoints, saved_models, and
        summaries will be written to.
      train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
        number of train steps. This is used for artifacts created in the
        root_dir, such as summaries and other outputs.
      agent: `tf_agent.TFAgent` instance to train with.
      experience_dataset_fn: a function that will create an instance of a
        tf.data.Dataset used to sample experience for training. Required for
        using the Learner as is. Optional for subclass learners which take a new
        iterator each time `learner.run` is called.
      after_train_strategy_step_fn: (Optional) callable of the form
        `fn(sample, loss)` which can be used for example to update priorities in
        a replay buffer where sample is pulled from the `experience_iterator`
        and loss is a `LossInfo` named tuple returned from the agent. This is
        called after every train step. It runs using `strategy.run(...)`.
      triggers: List of callables of the form `trigger(train_step)`. After every
        `run` call every trigger is called with the current `train_step` value
        as an np scalar.
      checkpoint_interval: Number of train steps in between checkpoints. Note
        these are placed into triggers and so a check to generate a checkpoint
        only occurs after every `run` call. Set to -1 to disable (this is not
        recommended, because it means that if the pipeline gets preempted, all
        previous progress is lost). This only takes care of checkpointing the
        training process. Policies must be explicitly exported through
        triggers.
      summary_interval: Number of train steps in between summaries. Note these
        are placed into triggers and so a check to generate a summary only
        occurs after every `run` call.
      max_checkpoints_to_keep: Maximum number of checkpoints to keep around.
        These are used to recover from pre-emptions when training.
      use_kwargs_in_agent_train: If True the experience from the replay buffer
        is passed into the agent as kwargs. This requires samples from the RB to
        be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`. This
        is useful if you have an agent with a custom argspec.
      strategy: (Optional) `tf.distribute.Strategy` to use during training.
    """
        if checkpoint_interval < 0:
            logging.warning(
                'Warning: checkpointing the training process is manually disabled. '
                'This means training progress will NOT be automatically restored '
                'if the job gets preempted.')

        self._train_dir = os.path.join(root_dir, TRAIN_DIR)
        self.train_summary_writer = tf.compat.v2.summary.create_file_writer(
            self._train_dir, flush_millis=10000)

        self.train_step = train_step
        self._agent = agent
        self.use_kwargs_in_agent_train = use_kwargs_in_agent_train
        self.strategy = strategy or tf.distribute.get_strategy()

        if experience_dataset_fn:
            with self.strategy.scope():
                dataset = self.strategy.experimental_distribute_datasets_from_function(
                    lambda _: experience_dataset_fn())
                self._experience_iterator = iter(dataset)

        self.after_train_strategy_step_fn = after_train_strategy_step_fn
        self.triggers = triggers or []

        # Prevent autograph from going into the agent.
        self._agent.train = tf.autograph.experimental.do_not_convert(
            agent.train)

        checkpoint_dir = os.path.join(self._train_dir, POLICY_CHECKPOINT_DIR)
        with self.strategy.scope():
            agent.initialize()

            self._checkpointer = common.Checkpointer(
                checkpoint_dir,
                max_to_keep=max_checkpoints_to_keep,
                agent=self._agent,
                train_step=self.train_step)
            self._checkpointer.initialize_or_restore()  # pytype: disable=attribute-error

        self.triggers.append(self._get_checkpoint_trigger(checkpoint_interval))
        self.summary_interval = tf.constant(summary_interval, dtype=tf.int64)
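
The snippet below is a minimal usage sketch for the Learner defined above, not part of the original example. It assumes an already constructed TF-Agents `agent` and an `experience_dataset_fn` that returns a `tf.data.Dataset` of experience; the paths and interval values are placeholders.

import tensorflow as tf
from tf_agents.train import learner, triggers

# Assumed to exist: `agent` (a tf_agents TFAgent) and `experience_dataset_fn`
# (a zero-argument callable returning a tf.data.Dataset of experience).
train_step = tf.Variable(0, dtype=tf.int64)
policy_trigger = triggers.PolicySavedModelTrigger(
    saved_model_dir='/tmp/learner_example/policy',  # placeholder path
    agent=agent,
    train_step=train_step,
    interval=1000)

dqn_learner = learner.Learner(
    root_dir='/tmp/learner_example',  # placeholder path
    train_step=train_step,
    agent=agent,
    experience_dataset_fn=experience_dataset_fn,
    triggers=[policy_trigger],
    checkpoint_interval=10000)

# Each call runs `iterations` train steps and then fires the triggers.
loss_info = dqn_learner.run(iterations=100)
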
Example #2
def train_eval(
        root_dir,
        env_name='MaskedCartPole-v0',
        num_iterations=100000,
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        train_sequence_length=10,
        # Params for collect
        initial_collect_steps=50,
        collect_episodes_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=10,
        batch_size=128,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        q_net = q_rnn_network.QRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params)

        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        experience, _ = iterator.get_next()
        loss_info = tf_agent.train(experience=experience)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            logging.info('Collecting initial experience.')
            sess.run(initial_collect_op)

            # Compute evaluation metrics.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(loss_info)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # tf.contrib was removed in TF 2.x; use the tf.compat.v2 summary API,
            # as the other examples below do.
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec', data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        log=True,
                        callback=eval_metrics_callback,
                    )
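
A hypothetical way to invoke the `train_eval` function above; the output directory and the reduced counts are illustrative only, and `MaskedCartPole-v0` is assumed to be registered with Gym as in the original script.

if __name__ == '__main__':
    train_eval(
        root_dir='/tmp/masked_cartpole',  # placeholder output directory
        num_iterations=2000,              # reduced for a quick smoke test
        log_interval=100,
        eval_interval=500)
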
Example #3
def train():
    summary_interval = 1000
    summaries_flush_secs = 10
    num_eval_episodes = 5
    root_dir = '/tmp/tensorflow/logs/tfenv01'
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs*1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs*1000)
    # maybe py_metrics?
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    environment = TradeEnvironment()
    # utils.validate_py_environment(environment, episodes=5)
    # Environments
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        train_env = tf_py_environment.TFPyEnvironment(environment)
        eval_env = tf_py_environment.TFPyEnvironment(environment)

        num_iterations = 50
        fc_layer_params = (512, )  # ~ (17 + 1001) / 2
        input_fc_layer_params = (50, )
        output_fc_layer_params = (20, )
        lstm_size = (30, )
        
        initial_collect_steps = 20
        collect_steps_per_iteration = 1
        collect_episodes_per_iteration = 1  # the same as above
        batch_size = 64
        replay_buffer_capacity = 10000
        
        train_sequence_length = 10

        gamma = 0.99  # check if 1.0 works as well
        target_update_tau = 0.05
        target_update_period = 5
        epsilon_greedy = 0.1
        gradient_clipping = None
        reward_scale_factor = 1.0

        learning_rate = 1e-2
        log_interval = 30
        eval_interval = 15
        # The following are referenced further down but were missing from the
        # original snippet; the values here are placeholders.
        train_steps_per_iteration = 1
        train_checkpoint_interval = 500
        policy_checkpoint_interval = 500
        rb_checkpoint_interval = 1000
        eval_metrics_callback = None

        # train_env.observation_spec(),
        q_net = q_rnn_network.QRnnNetwork(
            train_env.time_step_spec().observation,
            train_env.action_spec(),
            input_fc_layer_params=input_fc_layer_params,
            lstm_size=lstm_size,
            output_fc_layer_params=output_fc_layer_params,
        )

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        tf_agent = dqn_agent.DqnAgent(
            train_env.time_step_spec(),
            train_env.action_spec(),
            q_network=q_net,
            optimizer=optimizer,
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=False,
            summarize_grads_and_vars=False,
            train_step_counter=global_step,
        )

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=train_env.batch_size,
            max_length=replay_buffer_capacity,
        )

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Policy which does not allow some actions in certain states
        q_policy = FilteredQPolicy(
            tf_agent._time_step_spec, 
            tf_agent._action_spec, 
            q_network=tf_agent._q_network,
        )

        # Valid policy to pre-fill replay buffer
        initial_collect_policy = DummyTradePolicy(
            train_env.time_step_spec(),
            train_env.action_spec(),
        )
        print('Initial collecting...')
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            train_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps,
        ).run()

        # Main agent's policy; greedy one
        policy = greedy_policy.GreedyPolicy(q_policy)
        # Policy used for evaluation, the same as above
        eval_policy = greedy_policy.GreedyPolicy(q_policy)
    
        tf_agent._policy = policy
        collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
            q_policy, epsilon=tf_agent._epsilon_greedy)
        # Patch random policy for epsilon greedy collect policy
        filtered_random_tf_policy = FilteredRandomTFPolicy(
            time_step_spec=policy.time_step_spec,
            action_spec=policy.action_spec,
        )
        collect_policy._random_policy = filtered_random_tf_policy
        tf_agent._collect_policy = collect_policy
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            train_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration,
        ).run()
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length+1,
        ).prefetch(3)

        iterator = iter(dataset) 
        experience, _ = next(iterator)
        loss_info = common.function(tf_agent.train)(experience=experience)

        # Checkpoints
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        )
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step,
        )
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer,
        )
        
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(train_metric.tf_summaries(
                train_step=global_step,
                step_metrics=train_metrics[:2],
            ))

        with eval_summary_writer.as_default(), \
                tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()
    
        with tf.compat.v1.Session() as sess:
            # sess.run(train_summary_writer.init())
            # sess.run(eval_summary_writer.init())
            
            # Initialize the graph
            # tfe.Saver().restore()
            # train_checkpointer.initialize_or_restore()
            # rb_checkpointer.initialize_or_restore()
            # sess.run(iterator.initializer)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            print('Collecting initial experience...')
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([loss_info, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step,
            )

            # Train
            for i in range(num_iterations):
                start_time = time.time()
                collect_call()

                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    print('step=%d, loss=%f' %
                          (global_step_val, loss_info_value.loss))
                    steps_per_sec = (global_step_val-timed_at_step) / time_acc
                    print('%.3f steps/sec' % steps_per_sec)
                    sess.run(
                        steps_per_second_summary,
                        feed_dict={steps_per_second_ph: steps_per_sec},
                    )
                    timed_at_step = global_step_val
                    time_acc = 0

                # Save checkpoints
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                # Evaluate
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        log=True,
                        callback=eval_metrics_callback,
                    )
    print('Done!')        
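
The policy wiring in Example #3 relies on custom `FilteredQPolicy`, `DummyTradePolicy`, and `FilteredRandomTFPolicy` classes that are not shown. The sketch below reproduces the same greedy/epsilon-greedy pattern with stock TF-Agents policies; `time_step_spec`, `action_spec`, and `q_net` are assumed to come from an environment and a QRnnNetwork as above.

from tf_agents.policies import epsilon_greedy_policy, greedy_policy, q_policy

# Assumed to exist: `time_step_spec`, `action_spec`, and `q_net`.
base_policy = q_policy.QPolicy(time_step_spec, action_spec, q_network=q_net)

# Greedy wrapper for evaluation, epsilon-greedy wrapper for data collection.
eval_policy = greedy_policy.GreedyPolicy(base_policy)
collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(base_policy, epsilon=0.1)
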
Example #4
def train(
        root_dir,
        load_root_dir=None,
        env_load_fn=None,
        env_name=None,
        num_parallel_environments=1,  # pylint: disable=unused-argument
        agent_class=None,
        initial_collect_random=True,  # pylint: disable=unused-argument
        initial_collect_driver_class=None,
        collect_driver_class=None,
        num_global_steps=1000000,
        train_steps_per_iteration=1,
        train_metrics=None,
        # Safety Critic training args
        train_sc_steps=10,
        train_sc_interval=300,
        online_critic=False,
        # Params for eval
        run_eval=False,
        num_eval_episodes=30,
        eval_interval=1000,
        eval_metrics_callback=None,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        keep_rb_checkpoint=False,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        early_termination_fn=None,
        env_metric_factories=None):  # pylint: disable=unused-argument
    """A simple train and eval for SC-SAC."""

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    train_metrics = train_metrics or []

    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = env_load_fn(env_name)
        if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
            tf_env = tf_py_environment.TFPyEnvironment(tf_env)

        if run_eval:
            eval_py_env = env_load_fn(env_name)
            eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        print('obs spec:', observation_spec)
        print('action spec:', action_spec)

        if online_critic:
            resample_metric = tf_py_metric.TFPyMetric(
                py_metrics.CounterMetric('unsafe_ac_samples'))
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step,
                                   resample_metric=resample_metric)
        else:
            tf_agent = agent_class(time_step_spec,
                                   action_spec,
                                   train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        collect_data_spec = tf_agent.collect_data_spec

        logging.info('Allocating replay buffer ...')
        # Add to replay buffer and other agent specific observers.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            collect_data_spec, max_length=1000000)
        logging.info('RB capacity: %i', replay_buffer.capacity)
        logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

        agent_observers = [replay_buffer.add_batch]
        if online_critic:
            online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                collect_data_spec, max_length=10000)

            online_rb_ckpt_dir = os.path.join(train_dir,
                                              'online_replay_buffer')
            online_rb_checkpointer = common.Checkpointer(
                ckpt_dir=online_rb_ckpt_dir,
                max_to_keep=1,
                replay_buffer=online_replay_buffer)

            clear_rb = common.function(online_replay_buffer.clear)
            agent_observers.append(online_replay_buffer.add_batch)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

        if not online_critic:
            eval_policy = tf_agent.policy
        else:
            eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        if not online_critic:
            collect_policy = tf_agent.collect_policy
        else:
            collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        safety_critic_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'safety_critic'),
            safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
            global_step=global_step)
        rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
        rb_checkpointer = common.Checkpointer(ckpt_dir=rb_ckpt_dir,
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        if load_root_dir:
            load_root_dir = os.path.expanduser(load_root_dir)
            load_train_dir = os.path.join(load_root_dir, 'train')
            misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

        if load_root_dir is None:
            train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        safety_critic_checkpointer.initialize_or_restore()

        collect_driver = collect_driver_class(tf_env,
                                              collect_policy,
                                              observers=agent_observers +
                                              train_metrics)

        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        if not rb_checkpointer.checkpoint_exists:
            logging.info('Performing initial collection ...')
            common.function(
                initial_collect_driver_class(tf_env,
                                             initial_collect_policy,
                                             observers=agent_observers +
                                             train_metrics).run)()
            last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
            logging.info('Data saved after initial collection: %d steps',
                         last_id)
            tf.print(
                replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
                output_stream=logging.info)

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
            if FLAGS.viz_pm:
                eval_fig_dir = osp.join(eval_dir, 'figs')
                if not tf.io.gfile.isdir(eval_fig_dir):
                    tf.io.gfile.makedirs(eval_fig_dir)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)
        if online_critic:
            online_dataset = online_replay_buffer.as_dataset(
                num_parallel_calls=3, num_steps=2).prefetch(3)
            online_iterator = iter(online_dataset)

            @common.function
            def critic_train_step():
                """Builds critic training step."""
                experience, buf_info = next(online_iterator)
                if env_name in [
                        'IndianWell', 'IndianWell2', 'IndianWell3',
                        'DrunkSpider', 'DrunkSpiderShort'
                ]:
                    safe_rew = experience.observation['task_agn_rew']
                else:
                    safe_rew = agents.process_replay_buffer(
                        online_replay_buffer, as_tensor=True)
                    safe_rew = tf.gather(safe_rew,
                                         tf.squeeze(buf_info.ids),
                                         axis=1)
                ret = tf_agent.train_sc(experience, safe_rew)
                clear_rb()
                return ret

        @common.function
        def train_step():
            experience, _ = next(iterator)
            ret = tf_agent.train(experience)
            return ret

        if not early_termination_fn:
            early_termination_fn = lambda: False

        loss_diverged = False
        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0
        mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
        if online_critic:
            mean_resample_ac = tf.keras.metrics.Mean(
                name='mean_unsafe_ac_samples')
            resample_metric.reset()

        while (global_step.numpy() <= num_global_steps
               and not early_termination_fn()):
            # Collect and train.
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            if online_critic:
                mean_resample_ac(resample_metric.result())
                resample_metric.reset()
                if time_step.is_last():
                    resample_ac_freq = mean_resample_ac.result()
                    mean_resample_ac.reset_states()
                    tf.compat.v2.summary.scalar(name='unsafe_ac_samples',
                                                data=resample_ac_freq,
                                                step=global_step)

            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
                mean_train_loss(train_loss.loss)

            if online_critic:
                if global_step.numpy() % train_sc_interval == 0:
                    for _ in range(train_sc_steps):
                        sc_loss, lambda_loss = critic_train_step()  # pylint: disable=unused-variable

            total_loss = mean_train_loss.result()
            mean_train_loss.reset_states()
            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    loss_diverged = True
                    break
            else:
                loss_divergence_counter = 0

            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             total_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                safety_critic_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                if online_critic:
                    online_rb_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if run_eval and global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
                if FLAGS.viz_pm:
                    savepath = 'step{}.png'.format(global_step_val)
                    savepath = osp.join(eval_fig_dir, savepath)
                    misc.record_episode_vis_summary(eval_tf_env, eval_policy,
                                                    savepath)

    if not keep_rb_checkpoint:
        misc.cleanup_checkpoints(rb_ckpt_dir)

    if loss_diverged:
        # Raise an error at the very end after the cleanup.
        raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
            total_loss, global_step.numpy()))

    return total_loss
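
The loss-divergence check inside the training loop above can be factored into a small stand-alone helper. This is only a sketch: `MAX_LOSS` and `TERMINATE_AFTER_DIVERGED_LOSS_STEPS` are module-level constants in the original, so the defaults below are placeholders.

import math

def update_divergence_counter(total_loss, counter, max_loss=1e8, patience=10):
    """Returns (loss_diverged, new_counter), mirroring the logic used above."""
    if math.isnan(total_loss) or math.isinf(total_loss) or total_loss > max_loss:
        counter += 1
        return counter > patience, counter
    # A finite, reasonable loss resets the consecutive-divergence counter.
    return False, 0
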
Example #5
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if num_parallel_environments > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] *
                    num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_py_env = env_load_fn(eval_env_name)

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()
        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(b/126239733) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
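
Examples #4 and #5 both sample two-step transitions from a TFUniformReplayBuffer. The snippet below is a minimal, self-contained sketch of that sampling pattern; the flat `tf.TensorSpec` data spec is a stand-in for the agent's `collect_data_spec`, and in the real pipelines the buffer is filled by a driver rather than by hand.

import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

data_spec = tf.TensorSpec(shape=(3,), dtype=tf.float32, name='obs')  # stand-in spec
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec, batch_size=1, max_length=1000)

# Add a few dummy frames so the dataset has something to sample.
for i in range(10):
    replay_buffer.add_batch(tf.fill((1, 3), float(i)))

# Adjacent pairs of frames (num_steps=2) form the [Bx2x...] trajectories
# that agent.train consumes in the examples above.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=4, num_steps=2).prefetch(3)
experience, _ = next(iter(dataset))
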
Example #6
def train_eval(
        root_dir,
        env_name='MinitaurGoalVelocityEnv-v0',
        eval_env_name=None,
        env_load_fn=suite_pybullet.load,
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        ensemble=True,
        n_critics=10,
        run_eval=False,
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.keras.losses.mse,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=1.,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    if run_eval:
        eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    if run_eval:
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes)
        ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec, action_spec, fc_layer_params=actor_fc_layers)
        if ensemble:
            critic_nets, critic_optimizers = [], []
            for _ in range(n_critics):
                critic_net = critic_network.CriticNetwork(
                    (observation_spec, action_spec),
                    observation_fc_layer_params=critic_obs_fc_layers,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers)
                critic_optimizers.append(
                    tf.keras.optimizers.Adam(
                        learning_rate=critic_learning_rate))
                critic_nets.append(critic_net)
            tf_agent = ensemble_sac_agent.EnsembleSacAgent(
                time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_networks=critic_nets,
                actor_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=actor_learning_rate),
                critic_optimizers=critic_optimizers,
                alpha_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
        else:
            critic_net = critic_network.CriticNetwork(
                (observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)
            tf_agent = sac_agent.SacAgent(
                time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.keras.optimizers.Adam(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)

        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
            tf_py_metric.TFPyMetric(
                metrics.AverageEarlyFailureMetric(
                    buffer_size=num_eval_episodes,
                    batch_size=tf_env.batch_size))
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                 summarize_config=True)
        tf.function(config_saver.after_create_session)()

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        if not rb_checkpointer.checkpoint_exists:
            initial_collect_driver.run()

        if run_eval:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        while global_step.numpy() < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0 and run_eval:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
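The `num_steps=2` dataset above is what produces the `[B, 2, ...]`-shaped trajectory batches mentioned in the comment. A minimal, self-contained sketch of that sampling behaviour (the toy `data_spec` and values are illustrative, not taken from the example):

import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

# Toy spec: one 3-dimensional float observation per step.
data_spec = tf.TensorSpec(shape=(3,), dtype=tf.float32, name='obs')
rb = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec, batch_size=1, max_length=100)

# Add a few dummy steps; the leading dimension matches the buffer batch_size.
for i in range(10):
    rb.add_batch(tf.constant([[float(i)] * 3]))

# num_steps=2 samples pairs of adjacent steps, so batches have shape [B, 2, 3].
dataset = rb.as_dataset(sample_batch_size=4, num_steps=2).prefetch(3)
batch, _ = next(iter(dataset))
print(batch.shape)  # (4, 2, 3)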
Example #7
0
    # Sample a batch of data from the buffer and update the agent's network.
    iterator = iter(dataset)
    experience, unused_info = next(iterator)

    train_loss = agent.train(experience).loss

    log_interval = 5
    eval_interval = 5

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))

    if step % 150 == 0 and step >= 9000:
        policy_dir = os.path.join('waypoints', 'DOUBLE REWARDS_POLICY')
        tf_policy_saver = policy_saver.PolicySaver(agent.policy)
        tf_policy_saver.save(policy_dir)

        checkpoint_dir = os.path.join('waypoints', 'Double rewards_CP')
        train_checkpointer = common.Checkpointer(
            ckpt_dir=checkpoint_dir,
            max_to_keep=1,
            agent=agent,
            policy=agent.policy,
            replay_buffer=replay_buffer,
        )
        # Write the checkpoint; constructing the Checkpointer by itself does
        # not save anything to disk.
        train_checkpointer.save(global_step=step)

plt.plot(returns)
plt.grid()
plt.show()
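The `PolicySaver` call in the fragment above writes a standard SavedModel, so the exported policy can later be reloaded without rebuilding the agent. A short sketch of that follow-up step (it assumes the directory from the snippet already exists; the commented lines further assume an evaluation environment):

import os

import tensorflow as tf

# Reload the policy exported by PolicySaver above.
saved_policy = tf.saved_model.load(
    os.path.join('waypoints', 'DOUBLE REWARDS_POLICY'))

# The loaded object exposes the same action() interface as the live policy:
# time_step = eval_env.reset()
# action_step = saved_policy.action(time_step)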
Example #8
0
def train_eval(
    root_dir,
    offline_dir=None,
    random_seed=None,
    env_name='sawyer_push',
    eval_env_name=None,
    env_load_fn=get_env,
    max_episode_steps=1000,
    eval_episode_steps=1000,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    reset_goal_frequency=1000,  # virtual episode size for reset-free training
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    # reset-free parameters
    use_minimum=True,
    reset_lagrange_learning_rate=3e-4,
    value_threshold=None,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    # Td3 parameters
    actor_update_period=1,
    exploration_noise_std=0.1,
    target_policy_noise=0.1,
    target_policy_noise_clip=0.1,
    dqda_clipping=None,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    # video recording for the environment
    video_record_interval=10000,
    num_videos=0,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):

  start_time = time.time()

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  video_dir = os.path.join(eval_dir, 'videos')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    if FLAGS.use_reset_goals in [-1]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.GoalTerminalResetWrapper,
          num_success_states=FLAGS.num_success_states,
          full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [0, 1]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.ResetFreeWrapper,
          reset_goal_frequency=reset_goal_frequency,
          variable_horizon_for_reset=FLAGS.variable_reset_horizon,
          num_success_states=FLAGS.num_success_states,
          full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [2]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.CustomOracleResetWrapper,
          partial_reset_frequency=reset_goal_frequency,
          episodes_before_full_reset=max_episode_steps //
          reset_goal_frequency),)
    elif FLAGS.use_reset_goals in [3, 4]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.GoalTerminalResetFreeWrapper,
          reset_goal_frequency=reset_goal_frequency,
          num_success_states=FLAGS.num_success_states,
          full_reset_frequency=max_episode_steps),)
    elif FLAGS.use_reset_goals in [5, 7]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.CustomOracleResetGoalTerminalWrapper,
          partial_reset_frequency=reset_goal_frequency,
          episodes_before_full_reset=max_episode_steps //
          reset_goal_frequency),)
    elif FLAGS.use_reset_goals in [6]:
      gym_env_wrappers = (functools.partial(
          reset_free_wrapper.VariableGoalTerminalResetWrapper,
          full_reset_frequency=max_episode_steps),)

    if env_name == 'playpen_reduced':
      train_env_load_fn = functools.partial(
          env_load_fn, reset_at_goal=FLAGS.reset_at_goal)
    else:
      train_env_load_fn = env_load_fn

    env, env_train_metrics, env_eval_metrics, aux_info = train_env_load_fn(
        name=env_name,
        max_episode_steps=None,
        gym_env_wrappers=gym_env_wrappers)

    tf_env = tf_py_environment.TFPyEnvironment(env)
    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(name=eval_env_name,
                    max_episode_steps=eval_episode_steps)[0])

    eval_metrics += env_eval_metrics

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    if FLAGS.agent_type == 'sac':
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform))
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
      )

      critic_net_no_entropy = None
      critic_no_entropy_optimizer = None
      if FLAGS.use_no_entropy_q:
        critic_net_no_entropy = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform',
            name='CriticNetworkNoEntropy1')
        critic_no_entropy_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate)

      tf_agent = SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net,
          critic_network=critic_net,
          critic_network_no_entropy=critic_net_no_entropy,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          critic_no_entropy_optimizer=critic_no_entropy_optimizer,
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step)

    elif FLAGS.agent_type == 'td3':
      actor_net = actor_network.ActorNetwork(
          tf_env.time_step_spec().observation,
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers,
      )
      critic_net = critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform')

      tf_agent = Td3Agent(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          exploration_noise_std=exploration_noise_std,
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          actor_update_period=actor_update_period,
          dqda_clipping=dqda_clipping,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          target_policy_noise=target_policy_noise,
          target_policy_noise_clip=target_policy_noise_clip,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
      )

    tf_agent.initialize()

    if FLAGS.use_reset_goals > 0:
      if FLAGS.use_reset_goals in [4, 5, 6]:
        reset_goal_generator = ScheduledResetGoal(
            goal_dim=aux_info['reset_state_shape'][0],
            num_success_for_switch=FLAGS.num_success_for_switch,
            num_chunks=FLAGS.num_chunks,
            name='ScheduledResetGoalGenerator')
      else:
        # distance to initial state distribution
        initial_state_distance = state_distribution_distance.L2Distance(
            initial_state_shape=aux_info['reset_state_shape'])
        initial_state_distance.update(
            tf.constant(aux_info['reset_states'], dtype=tf.float32),
            update_type='complete')

        if use_tf_functions:
          initial_state_distance.distance = common.function(
              initial_state_distance.distance)
          tf_agent.compute_value = common.function(tf_agent.compute_value)

        # initialize reset / practice goal proposer
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator = ResetGoalGenerator(
              goal_dim=aux_info['reset_state_shape'][0],
              compute_value_fn=tf_agent.compute_value,
              distance_fn=initial_state_distance,
              use_minimum=use_minimum,
              value_threshold=value_threshold,
              lagrange_variable_max=FLAGS.lagrange_max,
              optimizer=tf.compat.v1.train.AdamOptimizer(
                  learning_rate=reset_lagrange_learning_rate),
              name='reset_goal_generator')
        else:
          reset_goal_generator = FixedResetGoal(
              distance_fn=initial_state_distance)

      # if use_tf_functions:
      #   reset_goal_generator.get_reset_goal = common.function(
      #       reset_goal_generator.get_reset_goal)

      # modify the reset-free wrapper to use the reset goal generator
      tf_env.pyenv.envs[0].set_reset_goal_fn(
          reset_goal_generator.get_reset_goal)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    if FLAGS.relabel_goals:
      cur_episode_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,
          scope='CurEpisodeReplayBuffer',
          max_length=int(2 * min(reset_goal_frequency, max_episode_steps)))

      # NOTE: the replay observer is replaced because we cannot register both
      # buffers' add_batch here; transitions are copied into replay_buffer
      # later by relabel_function.
      replay_observer = [cur_episode_buffer.add_batch]

    # initialize metrics and observers
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    train_metrics += env_train_metrics

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_agent.policy, use_tf_function=True)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)
    if use_tf_functions:
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    if offline_dir is not None:
      offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,
          max_length=int(1e5))  # this has to be 100_000
      offline_checkpointer = common.Checkpointer(
          ckpt_dir=offline_dir, max_to_keep=1, replay_buffer=offline_data)
      offline_checkpointer.initialize_or_restore()

      # set the reset candidates to be all the data in offline buffer
      if (FLAGS.use_reset_goals > 0 and
          reset_lagrange_learning_rate > 0) or FLAGS.use_reset_goals in [
              4, 5, 6, 7
          ]:
        tf_env.pyenv.envs[0].set_reset_candidates(
            nest_utils.unbatch_nested_tensors(offline_data.gather_all()))

    if replay_buffer.num_frames() == 0:
      if offline_dir is not None:
        copy_replay_buffer(offline_data, replay_buffer)
        print(replay_buffer.num_frames())

        # multiply offline data
        if FLAGS.relabel_offline_data:
          data_multiplier(replay_buffer,
                          tf_env.pyenv.envs[0].env.compute_reward)
          print('after data multiplication:', replay_buffer.num_frames())

      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=1)
      if use_tf_functions:
        initial_collect_driver.run = common.function(initial_collect_driver.run)

      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps with '
          'a random policy.', initial_collect_steps)

      time_step = None
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)

      for iter_idx in range(initial_collect_steps):
        time_step, policy_state = initial_collect_driver.run(
            time_step=time_step, policy_state=policy_state)

        if time_step.is_last() and FLAGS.relabel_goals:
          reward_fn = tf_env.pyenv.envs[0].env.compute_reward
          relabel_function(cur_episode_buffer, time_step, reward_fn,
                           replay_buffer)
          cur_episode_buffer.clear()

        if FLAGS.use_reset_goals > 0 and time_step.is_last(
        ) and FLAGS.num_reset_candidates > 0:
          tf_env.pyenv.envs[0].set_reset_candidates(
              replay_buffer.get_next(
                  sample_batch_size=FLAGS.num_reset_candidates)[0])

    else:
      time_step = None
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    # manual data save for plotting utils
    np_custom_save(os.path.join(eval_dir, 'eval_interval.npy'), eval_interval)
    try:
      average_eval_return = np_custom_load(
          os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
      average_eval_success = np_custom_load(
          os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
      average_eval_final_success = np_custom_load(
          os.path.join(eval_dir, 'average_eval_final_success.npy')).tolist()
    except:  # pylint: disable=bare-except
      average_eval_return = []
      average_eval_success = []
      average_eval_final_success = []

    print('initialization_time:', time.time() - start_time)
    for iter_idx in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )

      if time_step.is_last() and FLAGS.relabel_goals:
        reward_fn = tf_env.pyenv.envs[0].env.compute_reward
        relabel_function(cur_episode_buffer, time_step, reward_fn,
                         replay_buffer)
        cur_episode_buffer.clear()

      # reset goal generator updates
      if FLAGS.use_reset_goals > 0 and iter_idx % (
          FLAGS.reset_goal_frequency * collect_steps_per_iteration) == 0:
        if FLAGS.num_reset_candidates > 0:
          tf_env.pyenv.envs[0].set_reset_candidates(
              replay_buffer.get_next(
                  sample_batch_size=FLAGS.num_reset_candidates)[0])
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator.update_lagrange_multipliers()

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        if 'Heatmap' in train_metric.name:
          if global_step_val % summary_interval == 0:
            train_metric.tf_summaries(
                train_step=global_step, step_metrics=train_metrics[:2])
        else:
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2])

      if (global_step_val % summary_interval == 0 and
          FLAGS.use_reset_goals > 0 and reset_lagrange_learning_rate > 0):
        (reset_states, values, initial_state_distance_vals,
         lagrangian) = reset_goal_generator.update_summaries(
             step_counter=global_step)
        for vf_viz_metric in aux_info['value_fn_viz_metrics']:
          vf_viz_metric.tf_summaries(
              reset_states,
              values,
              train_step=global_step,
              step_metrics=train_metrics[:2])

        if FLAGS.debug_value_fn_for_reset:
          num_test_lagrange = 20
          hyp_lagranges = [
              1.0 * increment / num_test_lagrange
              for increment in range(num_test_lagrange + 1)
          ]

          door_pos = reset_states[
              np.argmin(initial_state_distance_vals.numpy() -
                        lagrangian.numpy() * values.numpy())][3:5]
          print('cur lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                (lagrangian.numpy(), door_pos[0], door_pos[1]))
          for lagrange in hyp_lagranges:
            door_pos = reset_states[
                np.argmin(initial_state_distance_vals.numpy() -
                          lagrange * values.numpy())][3:5]
            print('test lagrange: %.2f, cur reset goal: (%.2f, %.2f)' %
                  (lagrange, door_pos[0], door_pos[1]))
          print('\n')

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

        # numpy saves for plotting
        if 'AverageReturn' in results.keys():
          average_eval_return.append(results['AverageReturn'].numpy())
        if 'EvalSuccessfulAtAnyStep' in results.keys():
          average_eval_success.append(
              results['EvalSuccessfulAtAnyStep'].numpy())
        if 'EvalSuccessfulEpisodes' in results.keys():
          average_eval_final_success.append(
              results['EvalSuccessfulEpisodes'].numpy())
        elif 'EvalSuccessfulAtLastStep' in results.keys():
          average_eval_final_success.append(
              results['EvalSuccessfulAtLastStep'].numpy())

        if average_eval_return:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_return.npy'),
              average_eval_return)
        if average_eval_success:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_success.npy'),
              average_eval_success)
        if average_eval_final_success:
          np_custom_save(
              os.path.join(eval_dir, 'average_eval_final_success.npy'),
              average_eval_final_success)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)

      if global_step_val % video_record_interval == 0:
        for video_idx in range(num_videos):
          video_name = os.path.join(video_dir, str(global_step_val),
                                    'video_' + str(video_idx) + '.mp4')
          record_video(
              lambda: env_load_fn(  # pylint: disable=g-long-lambda
                  name=env_name,
                  max_episode_steps=max_episode_steps)[0],
              video_name,
              eval_py_policy,
              max_episode_length=eval_episode_steps)

    return train_loss
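The `np_custom_save` / `np_custom_load` helpers used for the plotting arrays above are not part of this snippet. A plausible minimal implementation (an assumption, not the original code) is a thin wrapper around `np.save` / `np.load` that goes through `tf.io.gfile` so it also works on non-local filesystems:

import numpy as np
import tensorflow as tf


def np_custom_save(path, value):
  # Serialize `value` as a .npy blob at `path`.
  with tf.io.gfile.GFile(path, 'wb') as f:
    np.save(f, np.array(value), allow_pickle=False)


def np_custom_load(path):
  # Load a .npy blob written by np_custom_save.
  with tf.io.gfile.GFile(path, 'rb') as f:
    return np.load(f)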
Example #9
0
def main(cfg):
    # Set up logging and checkpointing
    log_dir = Path(cfg.log_dir)
    checkpoint_dir = Path(cfg.checkpoint_dir)
    print('log_dir: {}'.format(log_dir))
    print('checkpoint_dir: {}'.format(checkpoint_dir))

    # Create env
    env = utils.get_env_from_cfg(cfg)
    tf_env = components.get_tf_py_env(env, cfg.num_input_channels)

    # Agents
    epsilon = tf.Variable(1.0)
    agents = []
    for i, g in enumerate(cfg.robot_config):
        robot_type = next(iter(g))
        q_net = components.QNetwork(
            tf_env.observation_spec(),
            num_output_channels=VectorEnv.get_num_output_channels(robot_type))
        optimizer = keras.optimizers.SGD(
            learning_rate=cfg.learning_rate,
            momentum=0.9)  # cfg.weight_decay is currently ignored
        agent_cls = dqn_agent.DdqnAgent if cfg.use_double_dqn else dqn_agent.DqnAgent
        agent = agent_cls(
            time_step_spec=tf_env.time_step_spec(),
            action_spec=components.get_action_spec(robot_type),
            q_network=q_net,
            optimizer=optimizer,
            epsilon_greedy=epsilon,
            target_update_period=(cfg.target_update_freq // cfg.train_freq),
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=cfg.discount_factors[i],
            gradient_clipping=cfg.grad_norm_clipping,
            train_step_counter=tf.Variable(
                0, dtype=tf.int64),  # Separate counter for each agent
        )
        agent.initialize()
        agent.train = common.function(agent.train)
        agents.append(agent)
    global_step = agents[0].train_step_counter

    # Replay buffers
    replay_buffers = [ReplayBuffer(cfg.replay_buffer_size) for _ in agents]

    # Checkpointing
    timestep_var = tf.Variable(0, dtype=tf.int64)
    agent_checkpointer = common.Checkpointer(ckpt_dir=str(checkpoint_dir /
                                                          'agents'),
                                             max_to_keep=5,
                                             agents=agents,
                                             timestep_var=timestep_var)
    agent_checkpointer.initialize_or_restore()
    if timestep_var.numpy() > 0:
        checkpoint_path = checkpoint_dir / 'checkpoint_{:08d}.pkl'.format(
            timestep_var.numpy())
        with open(checkpoint_path, 'rb') as f:
            replay_buffers = pickle.load(f)

    # Logging
    train_summary_writer = tf.summary.create_file_writer(str(log_dir /
                                                             'train'))
    train_summary_writer.set_as_default()

    time_step = tf_env.reset()
    learning_starts = round(cfg.learning_starts_frac * cfg.total_timesteps)
    total_timesteps_with_warm_up = learning_starts + cfg.total_timesteps
    start_timestep = timestep_var.numpy()
    for timestep in tqdm(range(start_timestep, total_timesteps_with_warm_up),
                         initial=start_timestep,
                         total=total_timesteps_with_warm_up,
                         file=sys.stdout):
        # Set exploration epsilon
        exploration_eps = 1 - (1 - cfg.final_exploration) * min(
            1,
            max(0, timestep - learning_starts) /
            (cfg.exploration_frac * cfg.total_timesteps))
        epsilon.assign(exploration_eps)

        # Run one collect step
        transitions_per_buffer = tf_env.pyenv.envs[0].store_time_step(
            time_step)
        robot_group_index = tf_env.pyenv.envs[0].current_robot_group_index()
        action_step = agents[robot_group_index].collect_policy.action(
            time_step)
        time_step = tf_env.step(action_step.action)

        # Store experience in buffers
        for i, transitions in enumerate(transitions_per_buffer):
            for transition in transitions:
                replay_buffers[i].push(*transition)

        # Train policies
        if timestep >= learning_starts and (timestep +
                                            1) % cfg.train_freq == 0:
            for i, agent in enumerate(agents):
                experience = replay_buffers[i].sample(cfg.batch_size)
                agent.train(experience)

        # Logging
        if tf_env.pyenv.envs[0].done():
            info = tf_env.pyenv.envs[0].get_info()
            tf.summary.scalar('timesteps', timestep + 1, global_step)
            tf.summary.scalar('steps', info['steps'], global_step)
            tf.summary.scalar('total_cubes', info['total_cubes'], global_step)

        # Checkpointing
        if ((timestep + 1) % cfg.checkpoint_freq == 0
                or timestep + 1 == total_timesteps_with_warm_up):
            # Save agents
            timestep_var.assign(timestep + 1)
            agent_checkpointer.save(timestep + 1)

            # Save replay buffers
            checkpoint_path = checkpoint_dir / 'checkpoint_{:08d}.pkl'.format(
                timestep + 1)
            with open(checkpoint_path, 'wb') as f:
                pickle.dump(replay_buffers, f)
            cfg.checkpoint_path = str(checkpoint_path)
            utils.save_config(log_dir / 'config.yml', cfg)

            # Remove old checkpoints
            checkpoint_paths = list(checkpoint_dir.glob('checkpoint_*.pkl'))
            checkpoint_paths.remove(checkpoint_path)
            for old_checkpoint_path in checkpoint_paths:
                old_checkpoint_path.unlink()

    # Export trained policies
    policy_dir = checkpoint_dir / 'policies'
    for i, agent in enumerate(agents):
        policy_saver.PolicySaver(agent.policy).save(
            str(policy_dir / 'robot_group_{:02}'.format(i + 1)))
    cfg.policy_path = str(policy_dir)
    utils.save_config(log_dir / 'config.yml', cfg)

    env.close()
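The exploration schedule in the loop above keeps epsilon at 1.0 during the warm-up phase and then decays it linearly down to `cfg.final_exploration` over `cfg.exploration_frac * cfg.total_timesteps` steps. A small worked example of the same formula (the concrete numbers are illustrative, not taken from any real config):

# Illustrative settings, not from the original cfg.
final_exploration = 0.01
learning_starts = 1000
exploration_frac = 0.1
total_timesteps = 10000


def exploration_eps(timestep):
    # Piecewise-linear decay: 1.0 during warm-up, then a linear ramp down to
    # final_exploration over exploration_frac * total_timesteps steps.
    return 1 - (1 - final_exploration) * min(
        1, max(0, timestep - learning_starts) /
        (exploration_frac * total_timesteps))


for t in (0, 1000, 1500, 2000, 5000):
    print(t, round(exploration_eps(t), 3))
# prints 1.0, 1.0, 0.505, 0.01, 0.01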
Example #10
0
def load_agents_and_create_videos(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=NUM_ITERATIONS,
        max_ep_steps=1000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=((100, )),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        num_atoms=51,
        min_q_value=-20,
        max_q_value=20,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        num_random_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None,
        random_metrics_callback=None):

    # Define the directories to read from
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    random_dir = os.path.join(root_dir, 'random')

    # Match the writers and metrics used in training
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)

    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    random_summary_writer = tf.compat.v2.summary.create_file_writer(
        random_dir, flush_millis=summaries_flush_secs * 1000)

    random_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Match the environments used in training
    tf_env = tf_py_environment.TFPyEnvironment(
        suite_gym.load(env_name, max_episode_steps=max_ep_steps))
    eval_py_env = suite_gym.load(env_name, max_episode_steps=max_ep_steps)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    # Match the agents used in training
    categorical_q_net = categorical_q_network.CategoricalQNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        num_atoms=num_atoms,
        fc_layer_params=fc_layer_params)

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    tf_agent = categorical_dqn_agent.CategoricalDqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        categorical_q_network=categorical_q_net,
        optimizer=optimizer,
        min_q_value=min_q_value,
        max_q_value=max_q_value,
        n_step_update=n_step_update,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        train_step_counter=global_step)

    tf_agent.initialize()

    train_metrics = [
        # tf_metrics.NumberOfEpisodes(),
        # tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=tf_agent,
                                             global_step=global_step,
                                             metrics=metric_utils.MetricsGroup(
                                                 train_metrics,
                                                 'train_metrics'))

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=eval_policy,
                                              global_step=global_step)

    rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'replay_buffer'),
                                          max_to_keep=1,
                                          replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    if use_tf_functions:
        # To speed up collect use common.function.
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

    random_policy = random_tf_policy.RandomTFPolicy(
        eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

    # Make movies of the trained agent and a random agent
    date_string = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')

    # Finally, use the saved policy to generate the video
    trained_filename = "trainedC51_" + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, tf_agent.policy,
                             trained_filename)

    # And, create one with a random agent for comparison
    random_filename = 'random_' + date_string
    create_policy_eval_video(eval_tf_env, eval_py_env, random_policy,
                             random_filename)
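`create_policy_eval_video` is not defined in this snippet; one common way to write such a helper (a sketch following the TF-Agents tutorials, i.e. an assumption about the missing function rather than its actual definition) rolls out the policy and streams rendered frames into an MP4 via `imageio`:

import imageio


def create_policy_eval_video(eval_tf_env, eval_py_env, policy, filename,
                             num_episodes=1, fps=30):
    # Roll out `policy` in the TF environment while rendering frames from the
    # underlying Python environment, writing them to '<filename>.mp4'.
    with imageio.get_writer(filename + '.mp4', fps=fps) as video:
        for _ in range(num_episodes):
            time_step = eval_tf_env.reset()
            video.append_data(eval_py_env.render())
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = eval_tf_env.step(action_step.action)
                video.append_data(eval_py_env.render())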
Example #11
0
    def learn(self, num_iterations=100000):
        dataset = self.replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=self.batch_size,
            num_steps=2).prefetch(3)

        iterator = iter(dataset)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            self.train_env,
            self.tf_agent.collect_policy,
            observers=self.replay_observer + self.train_metrics,
            num_steps=self.collect_steps_per_iteration)

        root_dir = self.log_dir
        root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(root_dir, 'train')
        eval_dir = os.path.join(root_dir, 'eval')
        checkpoint_dir = os.path.join(root_dir, 'checkpoint')
        policy_dir = os.path.join(root_dir, 'policy')
        best_policy_dir = os.path.join(root_dir, 'best_policy')

        saver_policy = policy_saver.PolicySaver(self.tf_agent.policy)
        train_checkpointer = common.Checkpointer(
            ckpt_dir=checkpoint_dir,
            max_to_keep=2,
            agent=self.tf_agent,
            policy=self.tf_agent.policy,
            replay_buffer=self.replay_buffer,
            global_step=self.global_step)

        if (not self.resume_training):
            train_summary_writer = tf.summary.create_file_writer(
                train_dir, flush_millis=self.summaries_flush_secs * 1000)
            train_summary_writer.set_as_default()
            # Reset the train step
            self.tf_agent.train_step_counter.assign(0)
        else:
            train_checkpointer.initialize_or_restore()
            self.global_step = tf.compat.v1.train.get_global_step()
            print('Resume global step: ', self.global_step)

        # (Optional) Optimize by wrapping some of the code in a graph using TF function.
        self.tf_agent.train = common.function(self.tf_agent.train)
        collect_driver.run = common.function(collect_driver.run)

        with tf.summary.record_if(lambda: tf.math.equal(
                self.global_step % self.summary_interval, 0)):
            for _ in tqdm(range(num_iterations)):
                collect_driver.run()

                experience, unused_info = next(iterator)
                train_loss = self.tf_agent.train(experience)

                for train_metric in self.train_metrics:
                    train_metric.tf_summaries(train_step=self.global_step,
                                              step_metrics=self.step_metrics)

                step = self.tf_agent.train_step_counter.numpy()

                if ((step % 100000) == 0):
                    train_checkpointer.save(self.global_step)
                    saver_policy.save(policy_dir)

                if (step % self.summary_interval == 0):
                    avg = self.compute_avg_return(self.num_eval_episodes)
                    tf.summary.scalar('Average Reward',
                                      avg,
                                      step=self.global_step)

                    if (avg > self.return_avg):
                        saver_policy.save(best_policy_dir)
                        self.return_avg = avg
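`self.compute_avg_return` is not shown above either; in the TF-Agents tutorials the equivalent helper is a plain rollout loop like the following (written here as a free function, so the signature is illustrative rather than the original method's):

def compute_avg_return(environment, policy, num_episodes=10):
    # Average undiscounted return of `policy` over `num_episodes` episodes.
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    return total_return / num_episodes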
Example #12
0
def train_eval(
    root_dir,
    environment_name="broken_reacher",
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    real_initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    real_collect_interval=10,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    classifier_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    train_on_real=False,
    delta_r_warmup=0,
    random_seed=0,
    checkpoint_dir=None,
):
  """A simple train and eval for SAC."""
  np.random.seed(random_seed)
  tf.random.set_seed(random_seed)
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, "train")
  eval_dir = os.path.join(root_dir, "eval")

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)

  if environment_name == "broken_reacher":
    get_env_fn = darc_envs.get_broken_reacher_env
  elif environment_name == "half_cheetah_obstacle":
    get_env_fn = darc_envs.get_half_cheetah_direction_env
  elif environment_name.startswith("broken_joint"):
    base_name = environment_name.split("broken_joint_")[1]
    get_env_fn = functools.partial(
        darc_envs.get_broken_joint_env, env_name=base_name)
  elif environment_name.startswith("falling"):
    base_name = environment_name.split("falling_")[1]
    get_env_fn = functools.partial(
        darc_envs.get_falling_env, env_name=base_name)
  else:
    raise NotImplementedError("Unknown environment: %s" % environment_name)

  eval_name_list = ["sim", "real"]
  eval_env_list = [get_env_fn(mode) for mode in eval_name_list]

  eval_metrics_list = []
  for name in eval_name_list:
    eval_metrics_list.append([
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, name="AverageReturn_%s" % name),
    ])

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env_real = get_env_fn("real")
    if train_on_real:
      tf_env = get_env_fn("real")
    else:
      tf_env = get_env_fn("sim")

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=(
            tanh_normal_projection_network.TanhNormalProjectionNetwork),
    )
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer="glorot_uniform",
        last_kernel_initializer="glorot_uniform",
    )

    classifier = classifiers.build_classifier(observation_spec, action_spec)

    tf_agent = darc_agent.DarcAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        classifier=classifier,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        classifier_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=classifier_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
    )
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity,
    )
    replay_observer = [replay_buffer.add_batch]

    real_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity,
    )
    real_replay_observer = [real_replay_buffer.add_batch]

    sim_train_metrics = [
        tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesSim"),
        tf_metrics.EnvironmentSteps(name="EnvironmentStepsSim"),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageReturnSim",
        ),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageEpisodeLengthSim",
        ),
    ]
    real_train_metrics = [
        tf_metrics.NumberOfEpisodes(name="NumberOfEpisodesReal"),
        tf_metrics.EnvironmentSteps(name="EnvironmentStepsReal"),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageReturnReal",
        ),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            name="AverageEpisodeLengthReal",
        ),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(
            sim_train_metrics + real_train_metrics, "train_metrics"),
    )
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, "policy"),
        policy=eval_policy,
        global_step=global_step,
    )
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, "replay_buffer"),
        max_to_keep=1,
        replay_buffer=(replay_buffer, real_replay_buffer),
    )

    if checkpoint_dir is not None:
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
      assert checkpoint_path is not None
      train_checkpointer._load_status = train_checkpointer._checkpoint.restore(   # pylint: disable=protected-access
          checkpoint_path)
      train_checkpointer._load_status.initialize_or_restore()  # pylint: disable=protected-access
    else:
      train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    if replay_buffer.num_frames() == 0:
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + sim_train_metrics,
          num_steps=initial_collect_steps,
      )
      real_initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env_real,
          initial_collect_policy,
          observers=real_replay_observer + real_train_metrics,
          num_steps=real_initial_collect_steps,
      )

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + sim_train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    real_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env_real,
        collect_policy,
        observers=real_replay_observer + real_train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    config_str = gin.operative_config_str()
    logging.info(config_str)
    with tf.compat.v1.gfile.Open(os.path.join(root_dir, "operative.gin"),
                                 "w") as f:
      f.write(config_str)

    if use_tf_functions:
      if replay_buffer.num_frames() == 0:
        # The initial collect drivers are only created when the buffers are
        # empty (see above), so only wrap them in that case.
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        real_initial_collect_driver.run = common.function(
            real_initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      real_collect_driver.run = common.function(real_collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    if replay_buffer.num_frames() == 0:
      logging.info(
          "Initializing replay buffer by collecting experience for %d steps with "
          "a random policy.",
          initial_collect_steps,
      )
      initial_collect_driver.run()
      real_initial_collect_driver.run()

    for eval_name, eval_env, eval_metrics in zip(eval_name_list, eval_env_list,
                                                 eval_metrics_list):
      metric_utils.eager_compute(
          eval_metrics,
          eval_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix="Metrics-%s" % eval_name,
      )
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    real_time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = (
        replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))
    real_dataset = (
        real_replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5))

    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)
    real_iterator = iter(real_dataset)

    def train_step():
      experience, _ = next(iterator)
      real_experience, _ = next(real_iterator)
      return tf_agent.train(experience, real_experience=real_experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      assert not policy_state  # We expect policy_state == ().
      if (global_step.numpy() % real_collect_interval == 0 and
          global_step.numpy() >= delta_r_warmup):
        real_time_step, policy_state = real_collect_driver.run(
            time_step=real_time_step,
            policy_state=policy_state,
        )

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info("step = %d, loss = %f", global_step_val, train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info("%.3f steps/sec", steps_per_sec)
        tf.compat.v2.summary.scalar(
            name="global_steps_per_sec", data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in sim_train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=sim_train_metrics[:2])
      for train_metric in real_train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=real_train_metrics[:2])

      if global_step_val % eval_interval == 0:
        for eval_name, eval_env, eval_metrics in zip(eval_name_list,
                                                     eval_env_list,
                                                     eval_metrics_list):
          metric_utils.eager_compute(
              eval_metrics,
              eval_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix="Metrics-%s" % eval_name,
          )
          metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
Beispiel #13
0
    print(tf_agent.collect_data_spec)
    print('Replay Buffer Created, start warming-up ...')
    _startTime = dt.datetime.now()

    # driver for warm-up
    # https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers/dynamic_episode_driver/DynamicEpisodeDriver
    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_episodes=warmupEpisodes)
    # Restore from the last checkpoint if requested; otherwise run the warm-up
    # driver to seed the replay buffer.
    if shouldContinueFromLastCheckpoint:
        train_checkpointer = common.Checkpointer(ckpt_dir=checkpointDir,
                                                 max_to_keep=1,
                                                 agent=tf_agent,
                                                 policy=tf_agent.policy,
                                                 replay_buffer=replay_buffer,
                                                 global_step=global_step)
        train_checkpointer.initialize_or_restore()
    else:
        initial_collect_driver.run()
    _timeCost = (dt.datetime.now() - _startTime).total_seconds()
    print('Replay Buffer Warm-up Done. (cost {:.3g} hours)'.format(_timeCost /
                                                                   3600.0))
    _startTime = dt.datetime.now()

    # Training

    print('Prepare for training ...')
    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        env,
Beispiel #14
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_py_env = env_load_fn(eval_env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

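        # In this graph-mode example, `run()` returns ops; they only interact
        # with the environment when executed in the session below.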
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

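        # TF1-style session loop: restore checkpoints, initialize the dataset
        # iterator and remaining variables, then alternate collect and train
        # calls, evaluating and checkpointing on their configured intervals.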
        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )
                sess.run(eval_summary_flush_op)

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    total_loss, _ = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_flush_op)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Beispiel #15
0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            debug_summaries=debug_summaries,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # TODO(oars): Refactor drivers to better handle policy states. Remove the
        # policy reset and passing down an empty policy state to the driver.
        collect_policy = tf_agent.collect_policy
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
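        # Note: this example drives the initial collection by episodes, so
        # `initial_collect_steps` is interpreted as a number of episodes here.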
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run(policy_state=policy_state)

        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
Beispiel #16
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    q_net = q_network.QNetwork(tensor_spec.from_spec(env.observation_spec()),
                               tensor_spec.from_spec(env.action_spec()),
                               fc_layer_params=fc_layer_params)

    # The agent and its train/collect ops must be built in the graph before the
    # session is created.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step)

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=agent.policy,
                                              global_step=global_step)

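    # Sample n_step_update + 1 consecutive frames per item so the agent can
    # form n-step TD targets from each sampled sequence.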
    ds = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                  num_steps=n_step_update + 1)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

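    # Calling the `common.function`-wrapped `agent.train` once here builds a
    # reusable graph op that is executed repeatedly via `session.make_callable`.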
    train_op = common.function(agent.train)(experience)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries(train_step=global_step)

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Copy the Q-network values to the target Q-network.
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0
        steps_per_second_ph = tf.compat.v1.placeholder(tf.float32,
                                                       shape=(),
                                                       name='steps_per_sec_ph')
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_second_ph,
            step=global_step)

        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time
            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    '%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # Reset timing to avoid counting eval time.
                timed_at_step = global_step_val
                start_time = time.time()
Beispiel #17
0
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            kl_cutoff_factor=0.0,
            initial_adaptive_kl_beta=0.0,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

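        # On-policy PPO update: train on everything gathered this iteration and
        # then clear the buffer; the control dependencies below ensure the clear
        # happens only after the train op has run.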
        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss, _ = sess.run([train_op, summary_ops])
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
Beispiel #18
0
  def testLossLearnerDifferentDistStrat(self, create_agent_fn):
    # Create the strategies used in the test. The second value is the per-core
    # batch size.
    bs_multiplier = 4
    strategies = {
        'default': (tf.distribute.get_strategy(), 4 * bs_multiplier),
        'one_device':
            (tf.distribute.OneDeviceStrategy('/cpu:0'), 4 * bs_multiplier),
        'mirrored': (tf.distribute.MirroredStrategy(), 1 * bs_multiplier),
    }
    if tf.config.list_logical_devices('TPU'):
      strategies['TPU'] = (_get_tpu_strategy(), 2 * bs_multiplier)
    else:
      logging.info('TPU hardware is not available, TPU strategy test skipped.')

    learners = {
        name: self._build_learner_with_strategy(create_agent_fn, strategy,
                                                per_core_batch_size)
        for name, (strategy, per_core_batch_size) in strategies.items()
    }
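    # Each learner entry is a tuple; the unpacking below treats it as
    # (learner, dataset, variables, train_step, <unused in this test>).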

    # Verify that the initial variable values in the learners are the same.
    default_strat_trainer, _, default_vars, _, _ = learners['default']
    for name, (trainer, _, variables, _, _) in learners.items():
      if name != 'default':
        self._assign_variables(default_strat_trainer, trainer)
        self.assertLen(variables, len(default_vars))
        for default_variable, variable in zip(default_vars, variables):
          self.assertAllEqual(default_variable, variable)

    # Calculate losses.
    losses = {}
    checkpoint_path = {}
    iterations = 1
    optimizer_variables = {}
    for name, (trainer, _, variables, train_step, _) in learners.items():
      old_vars = self.evaluate(variables)

      loss = trainer.run(iterations=iterations).loss
      logging.info('Using strategy: %s, the loss is: %s at train step: %s',
                   name, loss, train_step)

      new_vars = self.evaluate(variables)
      losses[name] = old_vars, loss, new_vars
      self.assertNotEmpty(trainer._agent._optimizer.variables())
      optimizer_variables[name] = trainer._agent._optimizer.variables()
      checkpoint_path[name] = trainer._checkpointer.manager.directory

    for name, path in checkpoint_path.items():
      logging.info('Checkpoint dir for learner %s: %s. Content: %s', name, path,
                   tf.io.gfile.listdir(path))
      checkpointer = common.Checkpointer(path)

      # Make sure that the checkpoint file exists, so the learner initialized
      # using the corresponding root directory will pick up the values in the
      # checkpoint file.
      self.assertTrue(checkpointer.checkpoint_exists)

      # Create a learner using an existing root directory containing the
      # checkpoint files.
      strategy, per_core_batch_size = strategies[name]
      learner_from_checkpoint = self._build_learner_with_strategy(
          create_agent_fn,
          strategy,
          per_core_batch_size,
          root_dir=os.path.join(path, '..', '..'))[0]

      # Check if the learner was in fact created based on an existing
      # checkpoint.
      self.assertTrue(learner_from_checkpoint._checkpointer.checkpoint_exists)

      # Check that the values of the variables of the learner initialized from
      # the checkpoint are the same as the values that were used to write the
      # checkpoint.
      original_learner = learners[name][0]
      self.assertAllClose(
          learner_from_checkpoint._agent.collect_policy.variables(),
          original_learner._agent.collect_policy.variables())
      self.assertAllClose(learner_from_checkpoint._agent._optimizer.variables(),
                          original_learner._agent._optimizer.variables())

    # Verify same dataset across learner calls.
    for item in tf.data.Dataset.zip(tuple([v[1] for v in learners.values()])):
      for i in range(1, len(item)):
        # Compare the default strategy's observations to the other datasets; the
        # second index extracts the trajectory from the (trajectory, sample_info)
        # tuple.
        self.assertAllEqual(item[0][0].observation, item[i][0].observation)

    # Check that the losses are close to each other.
    _, default_loss, _ = losses['default']
    for name, (_, loss, _) in losses.items():
      self._compare_losses(loss, default_loss, delta=1.e-2)

    # Check that the optimizer variables are close to each other.
    default_optimizer_vars = optimizer_variables['default']
    for name, optimizer_vars in optimizer_variables.items():
      self.assertAllClose(
          optimizer_vars,
          default_optimizer_vars,
          atol=1.e-2,
          rtol=1.e-2,
          msg=('The initial values of the optimizer variables for the strategy '
               '{} are significantly different from the initial values of the '
               'default strategy.').format(name))

    # Check that the variables changed after calling `learner.run`.
    for old_vars, _, new_vars in losses.values():
      dist_test_utils.check_variables_different(self, old_vars, new_vars)
Beispiel #19
0
def train_eval(
    root_dir,
    env_name='MultiGrid-Empty-5x5-v0',
    env_load_fn=multiagent_gym_suite.load,
    random_seed=0,
    # Architecture params
    agent_class=multiagent_ppo.MultiagentPPO,
    actor_fc_layers=(64, 64),
    value_fc_layers=(64, 64),
    lstm_size=(64,),
    conv_filters=64,
    conv_kernel=3,
    direction_fc=5,
    entropy_regularization=0.,
    use_attention_networks=False,
    # Specialized agents
    inactive_agent_ids=tuple(),
    # Params for collect
    num_environment_steps=25000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=5,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=2,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=2,
    eval_interval=5,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=100,
    log_interval=10,
    summary_interval=10,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=True,
    summarize_grads_and_vars=True,
    eval_metrics_callback=None,
    reinit_checkpoint_dir=None,
    debug=True):
  """A simple train and eval for PPO."""
  tf.compat.v1.enable_v2_behavior()

  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  if debug:
    logging.info('In debug mode, turning tf_functions off')
    use_tf_functions = False

  for a in inactive_agent_ids:
    logging.info('Fixing and not training agent %d', a)

  # Load multiagent gym environment and determine number of agents
  gym_env = env_load_fn(env_name)
  n_agents = gym_env.n_agents

  # Set up logging
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      multiagent_metrics.AverageReturnMetric(
          n_agents, buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)

    logging.info('Creating %d environments...', num_parallel_environments)
    wrappers = []
    if use_attention_networks:
      wrappers = [lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)]

    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(
            env_name,
            gym_kwargs=dict(seed=random_seed),
            gym_env_wrappers=wrappers))
    # pylint: disable=g-complex-comprehension
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment([
            functools.partial(env_load_fn, environment_name=env_name,
                              gym_env_wrappers=wrappers,
                              gym_kwargs=dict(seed=random_seed * 1234 + i))
            for i in range(num_parallel_environments)
        ]))


    logging.info('Preparing to train...')
    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    bonus_metrics = [
        multiagent_metrics.MultiagentScalar(
            n_agents, name='UnscaledMultiagentBonus', buffer_size=1000),
    ]
    train_metrics = step_metrics + [
        multiagent_metrics.AverageReturnMetric(
            n_agents, batch_size=num_parallel_environments),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=num_parallel_environments),
    ]

    logging.info('Creating agent...')
    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        n_agents=n_agents,
        learning_rate=learning_rate,
        actor_fc_layers=actor_fc_layers,
        value_fc_layers=value_fc_layers,
        lstm_size=lstm_size,
        conv_filters=conv_filters,
        conv_kernel=conv_kernel,
        direction_fc=direction_fc,
        entropy_regularization=entropy_regularization,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        inactive_agent_ids=inactive_agent_ids)
    tf_agent.initialize()
    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    logging.info('Allocating replay buffer ...')
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)
    logging.info('RB capacity: %i', replay_buffer.capacity)

    # If reinit_checkpoint_dir is provided, the last agent in the checkpoint is
    # reinitialized. The other agents are novices.
    # Otherwise, all agents are reinitialized from train_dir.
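    # Mechanically: restore everything from reinit_checkpoint_dir, stash
    # agents[:-1] in a temporary checkpoint, rebuild the multiagent (giving the
    # last agent fresh weights), then restore agents[:-1] from the stash and
    # mark them as non-learning.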
    if reinit_checkpoint_dir:
      reinit_checkpointer = common.Checkpointer(
          ckpt_dir=reinit_checkpoint_dir,
          agent=tf_agent,
      )
      reinit_checkpointer.initialize_or_restore()
      temp_dir = os.path.join(train_dir, 'tmp')
      agent_checkpointer = common.Checkpointer(
          ckpt_dir=temp_dir,
          agent=tf_agent.agents[:-1],
      )
      agent_checkpointer.save(global_step=0)
      tf_agent = agent_class(
          tf_env.time_step_spec(),
          tf_env.action_spec(),
          n_agents=n_agents,
          learning_rate=learning_rate,
          actor_fc_layers=actor_fc_layers,
          value_fc_layers=value_fc_layers,
          lstm_size=lstm_size,
          conv_filters=conv_filters,
          conv_kernel=conv_kernel,
          direction_fc=direction_fc,
          entropy_regularization=entropy_regularization,
          num_epochs=num_epochs,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          inactive_agent_ids=inactive_agent_ids,
          non_learning_agents=list(range(n_agents - 1)))
      agent_checkpointer = common.Checkpointer(
          ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
      agent_checkpointer.initialize_or_restore()
      tf.io.gfile.rmtree(temp_dir)
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=multiagent_metrics.MultiagentMetricsGroup(
            train_metrics + bonus_metrics, 'train_metrics'))
    if not reinit_checkpoint_dir:
      train_checkpointer.initialize_or_restore()
    logging.info('Successfully initialized train checkpointer')

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)

    collect_policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=collect_policy,
        global_step=global_step)
    collect_saved_model = policy_saver.PolicySaver(
        collect_policy, train_step=global_step)

    logging.info('Successfully initialized policy saver.')

    print('Using TFDriver')
    if use_attention_networks:
      collect_driver = utils.StateTFDriver(
          tf_env,
          collect_policy,
          observers=[replay_buffer.add_batch] + train_metrics,
          max_episodes=collect_episodes_per_iteration,
          disable_tf_function=not use_tf_functions)
    else:
      collect_driver = tf_driver.TFDriver(
          tf_env,
          collect_policy,
          observers=[replay_buffer.add_batch] + train_metrics,
          max_episodes=collect_episodes_per_iteration,
          disable_tf_function=not use_tf_functions)

    def train_step():
      trajectories = replay_buffer.gather_all()
      return tf_agent.train(experience=trajectories)

    if use_tf_functions:
      tf_agent.train = common.function(tf_agent.train, autograph=False)
      train_step = common.function(train_step)

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    # Number of consecutive steps for which the loss has diverged.
    loss_divergence_counter = 0

    # Save operative config as late as possible to include used configurables.
    if global_step.numpy() == 0:
      config_filename = os.path.join(
          train_dir, 'operative_config-{}.gin'.format(global_step.numpy()))
      with tf.io.gfile.GFile(config_filename, 'wb') as f:
        f.write(gin.operative_config_str())

    total_episodes = 0
    logging.info('Commencing train loop!')
    while environment_steps_metric.result() < num_environment_steps:
      global_step_val = global_step.numpy()

      # Evaluation
      if global_step_val % eval_interval == 0:
        if debug:
          logging.info('Performing evaluation at step %d', global_step_val)
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)

      # Collect data
      if debug:
        logging.info('Collecting at step %d', global_step_val)
      start_time = time.time()
      time_step = tf_env.reset()
      policy_state = collect_policy.get_initial_state(tf_env.batch_size)
      if use_attention_networks:
        # Attention networks require previous policy state to compute attention
        # weights.
        time_step.observation['policy_state'] = (
            policy_state['actor_network_state'][0],
            policy_state['actor_network_state'][1])
      collect_driver.run(time_step, policy_state)
      collect_time += time.time() - start_time

      total_episodes += collect_episodes_per_iteration
      if debug:
        logging.info('Have collected a total of %d episodes', total_episodes)

      # Train
      if debug:
        logging.info('Training at step %d', global_step_val)
      start_time = time.time()
      total_loss, extra_loss = train_step()
      replay_buffer.clear()
      train_time += time.time() - start_time

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          logging.info('Loss diverged for too many timesteps, breaking...')
          break
      else:
        loss_divergence_counter = 0

      for train_metric in train_metrics + bonus_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      if global_step_val % log_interval == 0:
        logging.info('step = %d, total loss = %f', global_step_val, total_loss)
        for a in range(n_agents):
          if not inactive_agent_ids or a not in inactive_agent_ids:
            logging.info('Loss for agent %d = %f', a, extra_loss[a].loss)
        steps_per_sec = ((global_step_val - timed_at_step) /
                         (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %.3f, train_time = %.3f', collect_time,
                     train_time)
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        saved_model_path = os.path.join(
            saved_model_dir, 'policy_' + ('%d' % global_step_val).zfill(9))
        saved_model.save(saved_model_path)
        collect_policy_checkpointer.save(global_step=global_step_val)
        collect_saved_model_path = os.path.join(
            saved_model_dir,
            'collect_policy_' + ('%d' % global_step_val).zfill(9))
        collect_saved_model.save(collect_saved_model_path)

    # One final eval before exiting.
    results = multiagent_metrics.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
        use_function=use_tf_functions,
        use_attention_networks=use_attention_networks
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    multiagent_metrics.log_metrics(eval_metrics)
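
The loop above follows the standard TF-Agents on-policy recipe: drive whole episodes into a replay buffer, train on replay_buffer.gather_all(), then replay_buffer.clear() before the next collection round. Below is a minimal, self-contained sketch of that recipe, using a REINFORCE agent on CartPole purely as an illustrative stand-in for the multiagent PPO agent above.

import tensorflow as tf
from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.replay_buffers import tf_uniform_replay_buffer

tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
actor_net = actor_distribution_network.ActorDistributionNetwork(
    tf_env.observation_spec(), tf_env.action_spec(), fc_layer_params=(64,))
agent = reinforce_agent.ReinforceAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    actor_network=actor_net,
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
    train_step_counter=tf.Variable(0, dtype=tf.int64))
agent.initialize()

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    agent.collect_data_spec, batch_size=tf_env.batch_size, max_length=2000)
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, agent.collect_policy,
    observers=[replay_buffer.add_batch], num_episodes=2)

for _ in range(10):
    driver.run()                                    # collect on-policy episodes
    experience = replay_buffer.gather_all()         # everything collected so far
    loss_info = agent.train(experience=experience)  # one update on the batch
    replay_buffer.clear()                           # discard stale on-policy data
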
Example #20
0
                               observers=[replay_buffer.add_batch],
                               num_steps=STEPS_PER_ITER)
    # Wrap the run function in a TF graph
    driver.run = common.function(driver.run)
    # Create driver for the random policy
    random_driver = DynamicStepDriver(env=train_env,
                                      policy=random_policy,
                                      observers=[replay_buffer.add_batch],
                                      num_steps=STEPS_PER_ITER)
    # Wrap the run function in a TF graph
    random_driver.run = common.function(random_driver.run)

    # Create a checkpointer
    checkpointer = common.Checkpointer(ckpt_dir=os.path.relpath('checkpoint'),
                                       max_to_keep=1,
                                       agent=agent,
                                       policy=agent.policy,
                                       replay_buffer=replay_buffer,
                                       global_step=global_step)
    checkpointer.initialize_or_restore()
    global_step = tf.compat.v1.train.get_global_step()

    # Create a policy saver
    policy_saver = PolicySaver(agent.policy)

    # Main training loop
    time_step, policy_state = None, None
    for it in range(N_ITERATIONS):
        if COLLECT_RANDOM:
            print('Running random driver...')
            time_step, policy_state = random_driver.run(time_step, policy_state)
        print('Running agent driver...')
    def __init__(
            self,
            root_dir,
            env_name,
            num_iterations=200,
            max_episode_frames=108000,  # ALE frames
            terminal_on_life_loss=False,
            conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2),
                               (64, (3, 3), 1)),
            fc_layer_params=(512, ),
            # Params for collect
            initial_collect_steps=80000,  # ALE frames
            epsilon_greedy=0.01,
            epsilon_decay_period=1000000,  # ALE frames
            replay_buffer_capacity=1000000,
            # Params for train
            train_steps_per_iteration=1000000,  # ALE frames
            update_period=16,  # ALE frames
            target_update_tau=1.0,
            target_update_period=32000,  # ALE frames
            batch_size=32,
            learning_rate=2.5e-4,
            gamma=0.99,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for eval
            do_eval=True,
            eval_steps_per_iteration=500000,  # ALE frames
            eval_epsilon_greedy=0.001,
            # Params for checkpoints, summaries, and logging
            log_interval=1000,
            summary_interval=1000,
            summaries_flush_secs=10,
            debug_summaries=False,
            summarize_grads_and_vars=False,
            eval_metrics_callback=None):
        """A simple Atari train and eval for DQN.

    Args:
      root_dir: Directory to write log files to.
      env_name: Fully-qualified name of the Atari environment (e.g. Pong-v0).
      num_iterations: Number of train/eval iterations to run.
      max_episode_frames: Maximum length of a single episode, in ALE frames.
      terminal_on_life_loss: Whether to simulate an episode termination when a
        life is lost.
      conv_layer_params: Params for convolutional layers of QNetwork.
      fc_layer_params: Params for fully connected layers of QNetwork.
      initial_collect_steps: Number of ALE frames to process before
        beginning to train. Since this is in ALE frames, there will be
        initial_collect_steps/4 items in the RB when training starts.
      epsilon_greedy: Final epsilon value to decay to for training.
      epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
        epsilon_greedy (defined above).
      replay_buffer_capacity: Maximum number of items to store in the RB.
      train_steps_per_iteration: Number of ALE frames to run through for each
        iteration of training.
      update_period: Run a train operation every update_period ALE frames.
      target_update_tau: Coefficient for soft target network updates (1.0 ==
        hard updates).
      target_update_period: Period, in ALE frames, to copy the live network to
        the target network.
      batch_size: Number of frames to include in each training batch.
      learning_rate: RMS optimizer learning rate.
      gamma: Discount for future rewards.
      reward_scale_factor: Scaling factor for rewards.
      gradient_clipping: Norm length to clip gradients.
      do_eval: If True, run an eval every iteration. If False, skip eval.
      eval_steps_per_iteration: Number of ALE frames to run through for each
        iteration of evaluation.
      eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
        totally greedy policy).
      log_interval: Log stats to the terminal every log_interval training
        steps.
      summary_interval: Write TF summaries every summary_interval training
        steps.
      summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
        seconds.
      debug_summaries: If True, write additional summaries for debugging (see
        dqn_agent for which summaries are written).
      summarize_grads_and_vars: If True, include gradients and network
        variables in summaries.
      eval_metrics_callback: A callback function that takes (metric_dict,
        global_step) as parameters. Called after every eval with the results of
        the evaluation.
    """
        self._update_period = update_period / ATARI_FRAME_SKIP
        self._train_steps_per_iteration = (train_steps_per_iteration /
                                           ATARI_FRAME_SKIP)
        self._do_eval = do_eval
        self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
        self._eval_epsilon_greedy = eval_epsilon_greedy
        self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
        self._summary_interval = summary_interval
        self._num_iterations = num_iterations
        self._log_interval = log_interval
        self._eval_metrics_callback = eval_metrics_callback

        with gin.unlock_config():
            gin.bind_parameter('AtariPreprocessing.terminal_on_life_loss',
                               terminal_on_life_loss)

        root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(root_dir, 'train')
        eval_dir = os.path.join(root_dir, 'eval')

        train_summary_writer = tf.compat.v2.summary.create_file_writer(
            train_dir, flush_millis=summaries_flush_secs * 1000)
        train_summary_writer.set_as_default()
        self._train_summary_writer = train_summary_writer

        self._eval_summary_writer = None
        if self._do_eval:
            self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_metrics = [
                py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                               buffer_size=np.inf),
                py_metrics.AverageEpisodeLengthMetric(
                    name='PhaseAverageEpisodeLength', buffer_size=np.inf),
            ]

        self._global_step = tf.compat.v1.train.get_or_create_global_step()
        with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
                self._global_step % self._summary_interval, 0)):
            self._env = suite_atari.load(
                env_name,
                max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
                gym_env_wrappers=(
                    suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING))
            self._env = batched_py_environment.BatchedPyEnvironment(
                [self._env])

            observation_spec = tensor_spec.from_spec(
                self._env.observation_spec())
            time_step_spec = ts.time_step_spec(observation_spec)
            action_spec = tensor_spec.from_spec(self._env.action_spec())

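            # The epsilon schedule below anneals exploration from 1.0 down to
            # `epsilon_greedy`, with the decay horizon converted from ALE
            # frames to train steps (divide by frame skip and update period)
            # and keyed off the global training step.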
            with tf.device('/cpu:0'):
                epsilon = tf.compat.v1.train.polynomial_decay(
                    1.0,
                    self._global_step,
                    epsilon_decay_period / ATARI_FRAME_SKIP /
                    self._update_period,
                    end_learning_rate=epsilon_greedy)

            with tf.device('/gpu:0'):
                optimizer = tf.compat.v1.train.RMSPropOptimizer(
                    learning_rate=learning_rate,
                    decay=0.95,
                    momentum=0.0,
                    epsilon=0.00001,
                    centered=True)
                q_net = AtariQNetwork(observation_spec,
                                      action_spec,
                                      conv_layer_params=conv_layer_params,
                                      fc_layer_params=fc_layer_params)
                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_net,
                    optimizer=optimizer,
                    epsilon_greedy=epsilon,
                    target_update_tau=target_update_tau,
                    target_update_period=(target_update_period /
                                          ATARI_FRAME_SKIP /
                                          self._update_period),
                    td_errors_loss_fn=dqn_agent.element_wise_huber_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=self._global_step)

                self._collect_policy = py_tf_policy.PyTFPolicy(
                    tf_agent.collect_policy)

                if self._do_eval:
                    self._eval_policy = py_tf_policy.PyTFPolicy(
                        epsilon_greedy_policy.EpsilonGreedyPolicy(
                            policy=tf_agent.policy,
                            epsilon=self._eval_epsilon_greedy))

                py_observation_spec = self._env.observation_spec()
                py_time_step_spec = ts.time_step_spec(py_observation_spec)
                py_action_spec = policy_step.PolicyStep(
                    self._env.action_spec())
                data_spec = trajectory.from_transition(py_time_step_spec,
                                                       py_action_spec,
                                                       py_time_step_spec)
                self._replay_buffer = (
                    py_hashed_replay_buffer.PyHashedReplayBuffer(
                        data_spec=data_spec, capacity=replay_buffer_capacity))

            with tf.device('/cpu:0'):
                ds = self._replay_buffer.as_dataset(
                    sample_batch_size=batch_size, num_steps=2).prefetch(4)
                ds = ds.apply(
                    tf.data.experimental.prefetch_to_device('/gpu:0'))

            with tf.device('/gpu:0'):
                self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
                experience = self._ds_itr.get_next()
                self._train_op = tf_agent.train(experience)

                self._env_steps_metric = py_metrics.EnvironmentSteps()
                self._step_metrics = [
                    py_metrics.NumberOfEpisodes(),
                    self._env_steps_metric,
                ]
                self._train_metrics = self._step_metrics + [
                    py_metrics.AverageReturnMetric(buffer_size=10),
                    py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
                ]
                # The _train_phase_metrics average over an entire train iteration,
                # rather than the rolling average of the last 10 episodes.
                self._train_phase_metrics = [
                    py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                                   buffer_size=np.inf),
                    py_metrics.AverageEpisodeLengthMetric(
                        name='PhaseAverageEpisodeLength', buffer_size=np.inf),
                ]
                self._iteration_metric = py_metrics.CounterMetric(
                    name='Iteration')

                # Summaries written from python should run every time they are
                # generated.
                with tf.compat.v2.summary.record_if(True):
                    self._steps_per_second_ph = tf.compat.v1.placeholder(
                        tf.float32, shape=(), name='steps_per_sec_ph')
                    self._steps_per_second_summary = tf.compat.v2.summary.scalar(
                        name='global_steps_per_sec',
                        data=self._steps_per_second_ph,
                        step=self._global_step)

                    for metric in self._train_metrics:
                        metric.tf_summaries(train_step=self._global_step,
                                            step_metrics=self._step_metrics)

                    for metric in self._train_phase_metrics:
                        metric.tf_summaries(
                            train_step=self._global_step,
                            step_metrics=(self._iteration_metric, ))
                    self._iteration_metric.tf_summaries(
                        train_step=self._global_step)

                    if self._do_eval:
                        with self._eval_summary_writer.as_default():
                            for metric in self._eval_metrics:
                                metric.tf_summaries(
                                    train_step=self._global_step,
                                    step_metrics=(self._iteration_metric, ))

                self._train_checkpointer = common.Checkpointer(
                    ckpt_dir=train_dir,
                    agent=tf_agent,
                    global_step=self._global_step,
                    optimizer=optimizer,
                    metrics=metric_utils.MetricsGroup(
                        self._train_metrics + self._train_phase_metrics +
                        [self._iteration_metric], 'train_metrics'))
                self._policy_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'policy'),
                    policy=tf_agent.policy,
                    global_step=self._global_step)
                self._rb_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                    max_to_keep=1,
                    replay_buffer=self._replay_buffer)

                self._init_agent_op = tf_agent.initialize()
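
The checkpointers above (training state, policy, replay buffer) plus a PolicySaver are the usual TF-Agents persistence setup. The following is a minimal sketch of the same Checkpointer/PolicySaver pattern on a toy DQN agent; the root_dir path and layer sizes are illustrative assumptions, not values from the example.

import os

import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.utils import common

root_dir = '/tmp/tf_agents_ckpt_demo'  # illustrative path
tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_network.QNetwork(
        tf_env.observation_spec(), tf_env.action_spec(),
        fc_layer_params=(64,)),
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
    train_step_counter=global_step)
agent.initialize()

# Restores agent weights and the step counter on restart; a no-op on first run.
train_checkpointer = common.Checkpointer(
    ckpt_dir=os.path.join(root_dir, 'train'),
    max_to_keep=3,
    agent=agent,
    global_step=global_step)
train_checkpointer.initialize_or_restore()

# Exports the greedy policy as a standalone SavedModel for eval/serving.
saver = policy_saver.PolicySaver(agent.policy, train_step=global_step)

# ... training iterations would go here ...
train_checkpointer.save(global_step=global_step.numpy())
saver.save(os.path.join(root_dir, 'policy_saved_model'))
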
Example #22
0
def main(_):
    # setting up
    start_time = time.time()
    tf.compat.v1.enable_resource_variables()
    tf.compat.v1.disable_eager_execution()
    logging.set_verbosity(logging.INFO)
    global observation_omit_size, goal_coord, sample_count, iter_count, episode_size_buffer, episode_return_buffer

    root_dir = os.path.abspath(os.path.expanduser(FLAGS.logdir))
    if not tf.io.gfile.exists(root_dir):
        tf.io.gfile.makedirs(root_dir)
    log_dir = os.path.join(root_dir, FLAGS.environment)

    if not tf.io.gfile.exists(log_dir):
        tf.io.gfile.makedirs(log_dir)
    save_dir = os.path.join(log_dir, 'models')
    if not tf.io.gfile.exists(save_dir):
        tf.io.gfile.makedirs(save_dir)

    print('directory for recording experiment data:', log_dir)

    # In case training was paused and resumed, restore the saved counters and
    # buffers.
    try:
        sample_count = np.load(os.path.join(log_dir,
                                            'sample_count.npy')).tolist()
        iter_count = np.load(os.path.join(log_dir, 'iter_count.npy')).tolist()
        episode_size_buffer = np.load(
            os.path.join(log_dir, 'episode_size_buffer.npy')).tolist()
        episode_return_buffer = np.load(
            os.path.join(log_dir, 'episode_return_buffer.npy')).tolist()
    except (OSError, ValueError):  # nothing saved yet on a fresh run
        sample_count = 0
        iter_count = 0
        episode_size_buffer = []
        episode_return_buffer = []

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(log_dir, 'train', 'in_graph_data'),
        flush_millis=10 * 1000)
    train_summary_writer.set_as_default()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(True):
        # environment related stuff
        env = do.get_environment(env_name=FLAGS.environment)
        py_env = wrap_env(skill_wrapper.SkillWrapper(
            env,
            num_latent_skills=FLAGS.num_skills,
            skill_type=FLAGS.skill_type,
            preset_skill=None,
            min_steps_before_resample=FLAGS.min_steps_before_resample,
            resample_prob=FLAGS.resample_prob),
                          max_episode_steps=FLAGS.max_env_steps)

        # all specifications required for all networks and agents
        py_action_spec = py_env.action_spec()
        tf_action_spec = tensor_spec.from_spec(
            py_action_spec)  # policy, critic action spec
        env_obs_spec = py_env.observation_spec()
        py_env_time_step_spec = ts.time_step_spec(
            env_obs_spec)  # replay buffer time_step spec
        if observation_omit_size > 0:
            agent_obs_spec = array_spec.BoundedArraySpec(
                (env_obs_spec.shape[0] - observation_omit_size, ),
                env_obs_spec.dtype,
                minimum=env_obs_spec.minimum,
                maximum=env_obs_spec.maximum,
                name=env_obs_spec.name)  # policy, critic observation spec
        else:
            agent_obs_spec = env_obs_spec
        py_agent_time_step_spec = ts.time_step_spec(
            agent_obs_spec)  # policy, critic time_step spec
        tf_agent_time_step_spec = tensor_spec.from_spec(
            py_agent_time_step_spec)

        if not FLAGS.reduced_observation:
            skill_dynamics_observation_size = (
                py_env_time_step_spec.observation.shape[0] - FLAGS.num_skills)
        else:
            skill_dynamics_observation_size = FLAGS.reduced_observation

        # TODO(architsh): Shift co-ordinate hiding to actor_net and critic_net (good for further image-based processing as well)
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_agent_time_step_spec.observation,
            tf_action_spec,
            fc_layer_params=(FLAGS.hidden_layer_size, ) * 2,
            continuous_projection_net=do._normal_projection_net)

        critic_net = critic_network.CriticNetwork(
            (tf_agent_time_step_spec.observation, tf_action_spec),
            observation_fc_layer_params=None,
            action_fc_layer_params=None,
            joint_fc_layer_params=(FLAGS.hidden_layer_size, ) * 2)

        if (FLAGS.skill_dynamics_relabel_type is not None and
                'importance_sampling' in FLAGS.skill_dynamics_relabel_type and
                FLAGS.is_clip_eps > 1.0):
            reweigh_batches_flag = True
        else:
            reweigh_batches_flag = False

        agent = dads_agent.DADSAgent(
            # DADS parameters
            save_dir,
            skill_dynamics_observation_size,
            observation_modify_fn=do.process_observation,
            restrict_input_size=observation_omit_size,
            latent_size=FLAGS.num_skills,
            latent_prior=FLAGS.skill_type,
            prior_samples=FLAGS.random_skills,
            fc_layer_params=(FLAGS.hidden_layer_size, ) * 2,
            normalize_observations=FLAGS.normalize_data,
            network_type=FLAGS.graph_type,
            num_mixture_components=FLAGS.num_components,
            fix_variance=FLAGS.fix_variance,
            reweigh_batches=reweigh_batches_flag,
            skill_dynamics_learning_rate=FLAGS.skill_dynamics_lr,
            # SAC parameters
            time_step_spec=tf_agent_time_step_spec,
            action_spec=tf_action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            target_update_tau=0.005,
            target_update_period=1,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.agent_lr),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.agent_lr),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=FLAGS.agent_lr),
            td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
            gamma=FLAGS.agent_gamma,
            reward_scale_factor=1. / (FLAGS.agent_entropy + 1e-12),
            gradient_clipping=None,
            debug_summaries=FLAGS.debug,
            train_step_counter=global_step)

        # evaluation policy
        eval_policy = py_tf_policy.PyTFPolicy(agent.policy)

        # collection policy
        if FLAGS.collect_policy == 'default':
            collect_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)
        elif FLAGS.collect_policy == 'ou_noise':
            collect_policy = py_tf_policy.PyTFPolicy(
                ou_noise_policy.OUNoisePolicy(agent.collect_policy,
                                              ou_stddev=0.2,
                                              ou_damping=0.15))

        # relabelling policy deals with batches of data, unlike collect and eval
        relabel_policy = py_tf_policy.PyTFPolicy(agent.collect_policy)

        # constructing a replay buffer, need a python spec
        policy_step_spec = policy_step.PolicyStep(action=py_action_spec,
                                                  state=(),
                                                  info=())

        if (FLAGS.skill_dynamics_relabel_type is not None and
                'importance_sampling' in FLAGS.skill_dynamics_relabel_type and
                FLAGS.is_clip_eps > 1.0):
            policy_step_spec = policy_step_spec._replace(
                info=policy_step.set_log_probability(
                    policy_step_spec.info,
                    array_spec.ArraySpec(
                        shape=(), dtype=np.float32, name='action_log_prob')))

        trajectory_spec = from_transition(py_env_time_step_spec,
                                          policy_step_spec,
                                          py_env_time_step_spec)
        capacity = FLAGS.replay_buffer_capacity
        # for all the data collected
        rbuffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
            capacity=capacity, data_spec=trajectory_spec)

        if FLAGS.train_skill_dynamics_on_policy:
            # for on-policy data (if something special is required)
            on_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
                capacity=FLAGS.initial_collect_steps + FLAGS.collect_steps +
                10,
                data_spec=trajectory_spec)

        # insert experience manually with relabelled rewards and skills
        agent.build_agent_graph()
        agent.build_skill_dynamics_graph()
        agent.create_savers()

        # saving this way requires the savers to live outside the agent object
        train_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            save_dir, 'agent'),
                                                 agent=agent,
                                                 global_step=global_step)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            save_dir, 'policy'),
                                                  policy=agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            save_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=rbuffer)

        setup_time = time.time() - start_time
        print('Setup time:', setup_time)

        with tf.compat.v1.Session().as_default() as sess:
            eval_policy.session = sess
            eval_policy.initialize(None)
            eval_policy.restore(os.path.join(FLAGS.logdir, 'models', 'policy'))

            plotdir = os.path.join(FLAGS.logdir, "plots")
            if not os.path.exists(plotdir):
                os.mkdir(plotdir)
            do.FLAGS = FLAGS
            do.eval_loop(eval_dir=plotdir,
                         eval_policy=eval_policy,
                         plot_name="plot")
Example #23
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
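
Evaluation in these train_eval functions goes through metric_utils.eager_compute, which rolls the policy out for num_episodes episodes, updates the metrics, and writes summaries under summary_prefix. Below is a hedged, self-contained sketch; the random policy and temporary log directory are stand-ins for the real policy and eval_dir.

import tensorflow as tf
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.policies import random_tf_policy

eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
eval_policy = random_tf_policy.RandomTFPolicy(
    eval_tf_env.time_step_spec(), eval_tf_env.action_spec())
eval_metrics = [
    tf_metrics.AverageReturnMetric(buffer_size=5),
    tf_metrics.AverageEpisodeLengthMetric(buffer_size=5),
]
global_step = tf.compat.v1.train.get_or_create_global_step()
eval_summary_writer = tf.compat.v2.summary.create_file_writer('/tmp/eval_demo')

# Runs 5 evaluation episodes, updates the metrics, and writes summaries
# tagged 'Metrics/...' at the current train step.
results = metric_utils.eager_compute(
    eval_metrics,
    eval_tf_env,
    eval_policy,
    num_episodes=5,
    train_step=global_step,
    summary_writer=eval_summary_writer,
    summary_prefix='Metrics',
)
metric_utils.log_metrics(eval_metrics)
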
Example #24
0
def train_eval(
        root_dir,
        env_name='gym_solventx-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        gym_env = gym.make(env_name, config_file=config_file)
        py_env = suite_gym.wrap_env(gym_env, max_episode_steps=100)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_gym_env = gym.make(env_name, config_file=config_file)
        eval_py_env = suite_gym.wrap_env(eval_gym_env, max_episode_steps=100)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        #tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        #eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name), config_file=config_file)

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step.numpy()).zfill(9))
                saved_model.save(saved_model_path)

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)
        return train_loss
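
The training half of the loop above samples two-step windows from the replay buffer as a tf.data.Dataset and feeds them to agent.train through a common.function-wrapped step. Here is a compact sketch of that dataset/iterator pattern on a toy DQN setup; the environment, layer sizes, and step counts are illustrative assumptions.

import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_network
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common

tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_network.QNetwork(
        tf_env.observation_spec(), tf_env.action_spec(),
        fc_layer_params=(64,)),
    optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=tf.Variable(0, dtype=tf.int64))
agent.initialize()

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=10000)

# Fill the buffer with a little experience first.
dynamic_step_driver.DynamicStepDriver(
    tf_env, agent.collect_policy,
    observers=[replay_buffer.add_batch], num_steps=200).run()

# Dataset yields trajectories shaped [batch_size, 2, ...] for 1-step TD updates.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=64, num_steps=2).prefetch(3)
iterator = iter(dataset)

def train_step():
    experience, _ = next(iterator)
    return agent.train(experience)

train_step = common.function(train_step)

for _ in range(10):
    loss_info = train_step()
print('final loss:', float(loss_info.loss))
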
Example #25
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=50000,
        policy_checkpoint_interval=50000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=(
                tanh_normal_projection_network.TanhNormalProjectionNetwork))
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

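        # Optionally wrap the hot paths in tf.function for graph-mode speed.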
        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

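        # Main loop: alternate environment collection with gradient updates,
        # logging, evaluation, and checkpointing at their configured intervals.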
        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
Beispiel #26
0
def train(
        root_dir,
        agent,
        environment,
        training_loops,
        steps_per_loop=1,
        additional_metrics=(),
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10,
        policy_checkpoint_interval=10,
        log_interval=10,
        summary_interval=10):
    """A training driver."""

    if not common.resource_variables_enabled():
        raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir)
    train_summary_writer.set_as_default()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(batch_size=environment.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=environment.batch_size),
        ] + list(additional_metrics)

        # Add to replay buffer and other agent specific observers.
        replay_buffer = build_replay_buffer(agent, environment.batch_size,
                                            steps_per_loop)
        agent_observers = [replay_buffer.add_batch] + train_metrics

        driver = dynamic_step_driver.DynamicStepDriver(
            env=environment,
            policy=agent.policy,
            num_steps=steps_per_loop * environment.batch_size,
            observers=agent_observers)

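        # Build the graph ops: a collection pass, a single deterministic read of
        # the buffer, a train op on that batch, and an op that clears the buffer.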
        collect_op, _ = driver.run()
        batch_size = driver.env.batch_size
        dataset = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                           num_steps=steps_per_loop,
                                           single_deterministic_pass=True)
        trajectories, unused_info = tf.data.experimental.get_single_element(
            dataset)
        train_op = agent.train(experience=trajectories)
        clear_replay_op = replay_buffer.clear()

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            max_to_keep=1,
            agent=agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  max_to_keep=None,
                                                  policy=agent.policy,
                                                  global_step=global_step)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        init_agent_op = agent.initialize()

        config_saver = utils.GinConfigSaverHook(train_dir,
                                                summarize_config=True)
        config_saver.begin()

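        # Everything above only defines graph ops; run them in a TF1 session.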
        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            config_saver.after_create_session(sess)

            global_step_call = sess.make_callable(global_step)
            global_step_val = global_step_call()

            sess.run(train_summary_writer.init())
            sess.run(collect_op)

            if global_step_val == 0:
                # Save an initial checkpoint so the evaluator runs for global_step=0.
                policy_checkpointer.save(global_step=global_step_val)
                sess.run(init_agent_op)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            clear_replay_call = sess.make_callable(clear_replay_op)

            timed_at_step = global_step_val
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

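            # Each loop: collect, train on the freshly collected batch, then
            # clear the buffer so the next loop only sees new data.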
            for _ in range(training_loops):
                # Collect and train.
                start_time = time.time()
                collect_call()
                total_loss, _ = train_step_call()
                clear_replay_call()
                global_step_val = global_step_call()

                time_acc += time.time() - start_time

                total_loss = total_loss.loss

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
Beispiel #27
0
driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch, metric],
    num_steps=collect_steps_per_iteration)
# Initial data collection
driver.run()
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2, single_deterministic_pass=False).prefetch(3)
iterator = iter(dataset)

train_checkpointer = common.Checkpointer(ckpt_dir=CHECKPOINT_DIR, max_to_keep=1,
                                         agent=agent, policy=agent.policy,
                                         replay_buffer=replay_buffer,
                                         global_step=global_step)


# train the agent
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)


def train_one_iteration():
  # Collect a few steps using collect_policy and save to the replay buffer.
  driver.run()

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  # The original example is truncated here; the standard TF-Agents pattern
  # (an assumption) is to train on the sampled batch and report the loss.
  train_loss = agent.train(experience).loss
  iteration = agent.train_step_counter.numpy()
  print('iteration: {0} loss: {1}'.format(iteration, train_loss))
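
# A minimal driver loop for the function above; num_iterations is a hypothetical
# count that is not defined in this (truncated) snippet.
for _ in range(num_iterations):
  train_one_iteration()
# Persist progress so a later run can restore from the checkpoint.
train_checkpointer.save(global_step)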
Beispiel #28
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(100, ),
        value_net_fc_layers=(100, ),
        use_value_network=False,
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gamma=0.9,
        gradient_clipping=None,
        normalize_returns=True,
        value_estimation_loss_coef=0.2,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        rb_checkpoint_interval=200,
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for Reinforce."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        # TODO(b/127870767): Handle distributions without gin.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        if use_value_network:
            value_net = value_network.ValueNetwork(
                tf_env.time_step_spec().observation,
                fc_layer_params=value_net_fc_layers)

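        # REINFORCE is a Monte-Carlo policy-gradient agent; the value network,
        # when enabled, serves as a learned baseline.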
        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            value_network=value_net if use_value_network else None,
            value_estimation_loss_coef=value_estimation_loss_coef,
            gamma=gamma,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

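        # Graph ops: collect whole episodes, train on everything gathered from
        # the buffer, then clear it (REINFORCE is on-policy).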
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        experience = replay_buffer.gather_all()
        train_op = tf_agent.train(experience)
        clear_rb_op = replay_buffer.clear()

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

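        # Run the graph in a TF1 session: restore checkpoints, evaluate once up
        # front, then enter the training loop.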
        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(b/126239733): Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Compute evaluation metrics.
            global_step_call = sess.make_callable(global_step)
            global_step_val = global_step_call()
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            clear_rb_call = sess.make_callable(clear_rb_op)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

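            # Each iteration: collect episodes, take one train step on the
            # gathered experience, clear the buffer, then do periodic logging,
            # checkpointing, and evaluation.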
            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                total_loss, _ = train_step_call()
                clear_rb_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )

dataset = _replay_buffer.as_dataset(num_parallel_calls=30,
                                    sample_batch_size=_batch_size,
                                    num_steps=2).prefetch(30)

_agent.train = common.function(_agent.train)

_agent.train_step_counter.assign(0)
print('initial collect...')
avg_return = compute_avg_return(_eval_env, _agent.policy, _num_eval_episodes)
returns = [avg_return]
iterator = iter(dataset)

train_checkpointer = common.Checkpointer(ckpt_dir=_checkpoint_policy_dir,
                                         max_to_keep=1,
                                         agent=_agent,
                                         policy=_agent.policy,
                                         replay_buffer=_replay_buffer,
                                         global_step=_train_step_counter)

tf_policy_saver = policy_saver.PolicySaver(_agent.policy)

restore_network = True

if restore_network:
    train_checkpointer.initialize_or_restore()

#_train_env.pyenv._envs[0].set_rendering(enabled=False)

while True:
    print('Collecting...')
    for _ in tqdm(range(_num_train_episodes)):
        # The body of this collection loop was truncated in the source; as an
        # assumption, a typical iteration would step the environment with the
        # agent's collect policy and add the transitions to _replay_buffer.
        pass

def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=3000000,
    actor_fc_layers=(),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    dual_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None,
    latent_dim=10,
    log_prob_reward_scale=0.0,
    predictor_updates_encoder=False,
    predict_prior=True,
    use_recurrent_actor=False,
    rnn_sequence_length=20,
    clip_max_stddev=10.0,
    clip_min_stddev=0.1,
    clip_mean=30.0,
    predictor_num_layers=2,
    use_identity_encoder=False,
    identity_encoder_single_stddev=False,
    kl_constraint=1.0,
    eval_dropout=(),
    use_residual_predictor=True,
    gym_kwargs=None,
    predict_prior_std=True,
    random_seed=0,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    if use_recurrent_actor:
        batch_size = batch_size // rnn_sequence_length
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        _build_env = functools.partial(
            suite_gym.load,
            environment_name=env_name,  # pylint: disable=invalid-name
            gym_env_wrappers=(),
            gym_kwargs=gym_kwargs)

        tf_env = tf_py_environment.TFPyEnvironment(_build_env())
        eval_vec = []  # (name, env, metrics)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes)
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(_build_env())
        name = ''
        eval_vec.append((name, eval_tf_env, eval_metrics))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        if latent_dim == 'obs':
            latent_dim = observation_spec.shape[0]

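        # The encoder outputs concatenated [mean, raw stddev] parameters;
        # _activation clips the mean and bounds the (pre-softplus) stddev.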
        def _activation(t):
            t1, t2 = tf.split(t, 2, axis=1)
            low = -np.inf if clip_mean is None else -clip_mean
            high = np.inf if clip_mean is None else clip_mean
            t1 = rpc_utils.squash_to_range(t1, low, high)

            if clip_min_stddev is None:
                low = -np.inf
            else:
                low = tf.math.log(tf.exp(clip_min_stddev) - 1.0)
            if clip_max_stddev is None:
                high = np.inf
            else:
                high = tf.math.log(tf.exp(clip_max_stddev) - 1.0)
            t2 = rpc_utils.squash_to_range(t2, low, high)
            return tf.concat([t1, t2], axis=1)

        if use_identity_encoder:
            assert latent_dim == observation_spec.shape[0]
            obs_input = tf.keras.layers.Input(observation_spec.shape)
            zeros = 0.0 * obs_input[:, :1]
            stddev_dim = 1 if identity_encoder_single_stddev else latent_dim
            pre_stddev = tf.keras.layers.Dense(stddev_dim,
                                               activation=None)(zeros)
            ones = zeros + tf.ones((1, latent_dim))
            pre_stddev = pre_stddev * ones  # Multiply to broadcast to latent_dim.
            pre_mean_stddev = tf.concat([obs_input, pre_stddev], axis=1)
            output = tfp.layers.IndependentNormal(latent_dim)(pre_mean_stddev)
            encoder_net = tf.keras.Model(inputs=obs_input, outputs=output)
        else:
            encoder_net = tf.keras.Sequential([
                tf.keras.layers.Dense(256, activation='relu'),
                tf.keras.layers.Dense(256, activation='relu'),
                tf.keras.layers.Dense(
                    tfp.layers.IndependentNormal.params_size(latent_dim),
                    activation=_activation,
                    kernel_initializer='glorot_uniform'),
                tfp.layers.IndependentNormal(latent_dim),
            ])

        # Build the predictor net
        obs_input = tf.keras.layers.Input(observation_spec.shape)
        action_input = tf.keras.layers.Input(action_spec.shape)

        class ConstantIndependentNormal(tfp.layers.IndependentNormal):
            """A keras layer that always returns N(0, 1) distribution."""
            def call(self, inputs):
                loc_scale = tf.concat([
                    tf.zeros((latent_dim, )),
                    tf.fill((latent_dim, ), tf.math.log(tf.exp(1.0) - 1))
                ],
                                      axis=0)
                # Multiply by a [B x 1] tensor to broadcast the batch dimension.
                loc_scale = loc_scale * tf.ones_like(inputs[:, :1])
                return super(ConstantIndependentNormal, self).call(loc_scale)

        if predict_prior:
            z = encoder_net(obs_input)
            if not predictor_updates_encoder:
                z = tf.stop_gradient(z)
            za = tf.concat([z, action_input], axis=1)
            if use_residual_predictor:
                za_input = tf.keras.layers.Input(za.shape[1])
                loc_scale = tf.keras.Sequential(
                    predictor_num_layers *
                    [tf.keras.layers.Dense(256, activation='relu')] + [  # pylint: disable=line-too-long
                        tf.keras.layers.Dense(tfp.layers.IndependentNormal.
                                              params_size(latent_dim),
                                              activation=_activation,
                                              kernel_initializer='zeros'),
                    ])(za_input)
                if predict_prior_std:
                    combined_loc_scale = tf.concat([
                        loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
                        loc_scale[:, latent_dim:]
                    ],
                                                   axis=1)
                else:
                    # Note that softplus(log(e - 1)) = 1.
                    combined_loc_scale = tf.concat([
                        loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
                        tf.math.log(np.e - 1) *
                        tf.ones_like(loc_scale[:, latent_dim:])
                    ],
                                                   axis=1)
                dist = tfp.layers.IndependentNormal(latent_dim)(
                    combined_loc_scale)
                output = tf.keras.Model(inputs=za_input, outputs=dist)(za)
            else:
                assert predict_prior_std
                output = tf.keras.Sequential(
                    predictor_num_layers *
                    [tf.keras.layers.Dense(256, activation='relu')] +  # pylint: disable=line-too-long
                    [
                        tf.keras.layers.Dense(tfp.layers.IndependentNormal.
                                              params_size(latent_dim),
                                              activation=_activation,
                                              kernel_initializer='zeros'),
                        tfp.layers.IndependentNormal(latent_dim),
                    ])(za)
        else:
            # scale is chosen by inverting the softplus function to equal 1.
            if len(obs_input.shape) > 2:
                input_reshaped = tf.reshape(
                    obs_input,
                    [-1, tf.math.reduce_prod(obs_input.shape[1:])])
                #  Multiply by a [B x 1] tensor to broadcast the batch dimension.
                za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(input_reshaped[:, :1])  # pylint: disable=line-too-long
            else:
                #  Multiply by a [B x 1] tensor to broadcast the batch dimension.
                za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(obs_input[:, :1])  # pylint: disable=line-too-long
            output = tf.keras.Sequential([
                ConstantIndependentNormal(latent_dim),
            ])(za)
        predictor_net = tf.keras.Model(inputs=(obs_input, action_input),
                                       outputs=output)
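        # The actor wraps the encoder and predictor so the policy acts on the
        # latent representation; the recurrent variant is selected when
        # use_recurrent_actor is set.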
        if use_recurrent_actor:
            ActorClass = rpc_utils.RecurrentActorNet  # pylint: disable=invalid-name
        else:
            ActorClass = rpc_utils.ActorNet  # pylint: disable=invalid-name
        actor_net = ActorClass(input_tensor_spec=observation_spec,
                               output_tensor_spec=action_spec,
                               encoder=encoder_net,
                               predictor=predictor_net,
                               fc_layers=actor_fc_layers)

        critic_net = rpc_utils.CriticNet(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')
        critic_net_2 = None
        target_critic_net_1 = None
        target_critic_net_2 = None

        tf_agent = rpc_agent.RpAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            critic_network_2=critic_net_2,
            target_critic_network=target_critic_net_1,
            target_critic_network_2=target_critic_net_2,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        dual_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=dual_learning_rate)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]
        kl_metric = rpc_utils.AverageKLMetric(encoder=encoder_net,
                                              predictor=predictor_net,
                                              batch_size=tf_env.batch_size)
        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        checkpoint_items = {
            'ckpt_dir': train_dir,
            'agent': tf_agent,
            'global_step': global_step,
            'metrics': metric_utils.MetricsGroup(train_metrics,
                                                 'train_metrics'),
            'dual_optimizer': dual_optimizer,
        }
        train_checkpointer = common.Checkpointer(**checkpoint_items)

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps,
            transition_observers=[kl_metric])

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration,
            transition_observers=[kl_metric])

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        for name, eval_tf_env, eval_metrics in eval_vec:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics-%s' % name,
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics, prefix=name)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0
        train_time_acc = 0
        env_time_acc = 0

        if use_recurrent_actor:  # default from sac/train_eval_rnn.py
            num_steps = rnn_sequence_length + 1

            def _filter_invalid_transition(trajectories, unused_arg1):
                return tf.reduce_all(~trajectories.is_boundary()[:-1])

            tf_agent._as_transition = data_converter.AsTransition(  # pylint: disable=protected-access
                tf_agent.data_context,
                squeeze_time_dim=False)
        else:
            num_steps = 2

            def _filter_invalid_transition(trajectories, unused_arg1):
                return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=num_steps).unbatch().filter(_filter_invalid_transition)

        dataset = dataset.batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [B x num_steps x ...].
        iterator = iter(dataset)

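        # Custom train step: compute the KL between the predictor's prior and
        # the encoded next observation, update the dual KL coefficient, subtract
        # the (scaled) KL from the reward, then run the regular SAC update.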
        @tf.function
        def train_step():
            experience, _ = next(iterator)

            prior = predictor_net(
                (experience.observation[:, 0], experience.action[:, 0]),
                training=False)
            z_next = encoder_net(experience.observation[:, 1], training=False)
            # predictor_kl is a vector of size batch_size.
            predictor_kl = tfp.distributions.kl_divergence(z_next, prior)

            with tf.GradientTape() as tape:
                tape.watch(actor_net._log_kl_coefficient)  # pylint: disable=protected-access
                dual_loss = -1.0 * actor_net._log_kl_coefficient * (  # pylint: disable=protected-access
                    tf.stop_gradient(tf.reduce_mean(predictor_kl)) -
                    kl_constraint)
            dual_grads = tape.gradient(dual_loss,
                                       [actor_net._log_kl_coefficient])  # pylint: disable=protected-access
            grads_and_vars = list(
                zip(dual_grads, [actor_net._log_kl_coefficient]))  # pylint: disable=protected-access
            dual_optimizer.apply_gradients(grads_and_vars)

            # Clip the dual variable so exp(log_kl_coef) <= 1e6.
            log_kl_coef = tf.clip_by_value(
                actor_net._log_kl_coefficient,  # pylint: disable=protected-access
                -1.0 * np.log(1e6),
                np.log(1e6))
            actor_net._log_kl_coefficient.assign(log_kl_coef)  # pylint: disable=protected-access

            with tf.name_scope('dual_loss'):
                tf.compat.v2.summary.scalar(name='dual_loss',
                                            data=tf.reduce_mean(dual_loss),
                                            step=global_step)
                tf.compat.v2.summary.scalar(
                    name='log_kl_coefficient',
                    data=actor_net._log_kl_coefficient,  # pylint: disable=protected-access
                    step=global_step)

            z_entropy = z_next.entropy()
            log_prob = prior.log_prob(z_next.sample())
            with tf.name_scope('rp-metrics'):
                common.generate_tensor_summaries('predictor_kl', predictor_kl,
                                                 global_step)
                common.generate_tensor_summaries('z_entropy', z_entropy,
                                                 global_step)
                common.generate_tensor_summaries('log_prob', log_prob,
                                                 global_step)
                common.generate_tensor_summaries('z_mean', z_next.mean(),
                                                 global_step)
                common.generate_tensor_summaries('z_stddev', z_next.stddev(),
                                                 global_step)
                common.generate_tensor_summaries('prior_mean', prior.mean(),
                                                 global_step)
                common.generate_tensor_summaries('prior_stddev',
                                                 prior.stddev(), global_step)

            if log_prob_reward_scale == 'auto':
                coef = tf.stop_gradient(tf.exp(actor_net._log_kl_coefficient))  # pylint: disable=protected-access
            else:
                coef = log_prob_reward_scale
            tf.debugging.check_numerics(tf.reduce_mean(predictor_kl),
                                        'predictor_kl is inf or nan.')
            tf.debugging.check_numerics(coef, 'coef is inf or nan.')
            new_reward = experience.reward - coef * predictor_kl[:, None]

            experience = experience._replace(reward=new_reward)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Save the hyperparameters
        operative_filename = os.path.join(root_dir, 'operative.gin')
        with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
            f.write(gin.operative_config_str())
            print(gin.operative_config_str())

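        # Main loop, as in the plain SAC example, with separate timers for
        # environment and train time and per-phase steps/sec summaries.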
        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            env_time_acc += time.time() - start_time
            train_start_time = time.time()
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            train_time_acc += time.time() - train_start_time
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                train_steps_per_sec = (global_step_val -
                                       timed_at_step) / train_time_acc
                logging.info('Train: %.3f steps/sec', train_steps_per_sec)
                tf.compat.v2.summary.scalar(name='train_steps_per_sec',
                                            data=train_steps_per_sec,
                                            step=global_step)
                env_steps_per_sec = (global_step_val -
                                     timed_at_step) / env_time_acc
                logging.info('Env: %.3f steps/sec', env_steps_per_sec)
                tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                            data=env_steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0
                train_time_acc = 0
                env_time_acc = 0

            for train_metric in train_metrics + [kl_metric]:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                start_time = time.time()
                for name, eval_tf_env, eval_metrics in eval_vec:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='Metrics-%s' % name,
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics, prefix=name)
                logging.info('Evaluation: %d min',
                             (time.time() - start_time) / 60)
                for prob_dropout in eval_dropout:
                    rpc_utils.eval_dropout_fn(eval_tf_env,
                                              actor_net,
                                              global_step,
                                              prob_dropout=prob_dropout)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)