Esempio n. 1
0
    def testAverageOneEpisode(self, metric_class, expected_result):
        """Feeds one boundary/mid/mid/last episode and checks the metric."""
        metric_under_test = metric_class()

        episode_steps = (
            trajectory.boundary((), (), (), 0., 1.),
            trajectory.mid((), (), (), 1., 1.),
            trajectory.mid((), (), (), 2., 1.),
            trajectory.last((), (), (), 3., 0.),
        )
        for step in episode_steps:
            metric_under_test(step)

        self.assertEqual(expected_result, metric_under_test.result())
Esempio n. 2
0
    def testBatchSizeProvided(self, metric_class, expected_result):
        """Drives a batch-of-2 metric with interleaved episodes of each env."""
        metric = metric_class(batch_size=2)

        # Each entry is (env 0 step, env 1 step) for one call of the metric.
        per_call_steps = [
            (trajectory.boundary((), (), (), 0., 1.),
             trajectory.boundary((), (), (), 0., 1.)),
            (trajectory.first((), (), (), 1., 1.),
             trajectory.first((), (), (), 1., 1.)),
            (trajectory.mid((), (), (), 2., 1.),
             trajectory.last((), (), (), 3., 0.)),
            (trajectory.last((), (), (), 3., 0.),
             trajectory.boundary((), (), (), 0., 1.)),
            (trajectory.boundary((), (), (), 0., 1.),
             trajectory.first((), (), (), 1., 1.)),
        ]
        for env0_step, env1_step in per_call_steps:
            batched = nest_utils.stack_nested_arrays([env0_step, env1_step])
            metric(batched)

        self.assertEqual(metric.result(), expected_result)
Esempio n. 3
0
    def testAverageOneEpisodeWithReset(self, metric_class, expected_result):
        """Checks the metric when an episode restarts without a LAST step."""
        metric = metric_class()

        unfinished_steps = (
            trajectory.first((), (), (), 0., 1.),
            trajectory.mid((), (), (), 1., 1.),
            trajectory.mid((), (), (), 2., 1.),
        )
        for step in unfinished_steps:
            metric(step)

        # The episode is reset.
        #
        # This can happen when the dynamic_episode_driver is used with a
        # parallel_py_environment. If the parallel episodes have different
        # lengths and num_episodes is reached, some episodes are left in "MID".
        # When the driver runs again, all environments are reset at the start
        # of the tf.while_loop, so the unfinished episodes receive "FIRST"
        # without ever having seen "LAST".
        restarted_steps = (
            trajectory.first((), (), (), 3., 1.),
            trajectory.last((), (), (), 4., 1.),
        )
        for step in restarted_steps:
            metric(step)

        self.assertEqual(expected_result, metric.result())
 def testMidArrays(self):
     """trajectory.mid on NumPy inputs yields non-tensor MID step types."""
     rewards = np.array([1.0, 1.0, 2.0])
     discounts = np.array([1.0, 1.0, 1.0])
     empty = ()
     traj = trajectory.mid(empty, empty, empty, rewards, discounts)

     expected_types = [ts.StepType.MID] * 3
     self.assertFalse(tf.is_tensor(traj.step_type))
     self.assertAllEqual(traj.step_type, expected_types)
     self.assertAllEqual(traj.next_step_type, expected_types)
 def testMidTensors(self):
     """trajectory.mid on tensor inputs yields tensor MID step types."""
     rewards = tf.constant([1.0, 1.0, 2.0])
     discounts = tf.constant([1.0, 1.0, 1.0])
     empty = ()
     traj = trajectory.mid(empty, empty, empty, rewards, discounts)
     self.assertTrue(tf.is_tensor(traj.step_type))

     evaluated = self.evaluate(traj)
     expected_types = [ts.StepType.MID] * 3
     for step_types in (evaluated.step_type, evaluated.next_step_type):
         self.assertAllEqual(step_types, expected_types)
Esempio n. 6
0
    def testSaveRestore(self):
        """Metric state survives a checkpoint save, reset and restore."""
        tracked_metrics = [
            py_metrics.AverageReturnMetric(),
            py_metrics.AverageEpisodeLengthMetric(),
            py_metrics.EnvironmentSteps(),
            py_metrics.NumberOfEpisodes()
        ]

        # Feed every metric its own freshly built single episode.
        for tracked in tracked_metrics:
            episode_steps = (
                trajectory.boundary((), (), (), 0., 1.),
                trajectory.mid((), (), (), 1., 1.),
                trajectory.mid((), (), (), 2., 1.),
                trajectory.last((), (), (), 3., 0.),
            )
            for step in episode_steps:
                tracked(step)

        named_metrics = {tracked.name: tracked for tracked in tracked_metrics}
        checkpoint = tf.train.Checkpoint(**named_metrics)
        save_path = checkpoint.save(self.get_temp_dir() + '/ckpt')

        # Resetting must zero every metric...
        for tracked in tracked_metrics:
            tracked.reset()
            self.assertEqual(0, tracked.result())

        # ...and restoring must bring the non-zero values back.
        checkpoint.restore(save_path).assert_consumed()
        for tracked in tracked_metrics:
            self.assertGreater(tracked.result(), 0)
Esempio n. 7
0
    def testAverageTwoEpisode(self, metric_class, expected_result):
        """Feeds a full episode followed by a single FIRST->LAST episode."""
        metric = metric_class()

        leading_steps = (
            trajectory.boundary((), (), (), 0., 1.),
            trajectory.first((), (), (), 1., 1.),
            trajectory.mid((), (), (), 2., 1.),
            trajectory.last((), (), (), 3., 0.),
            trajectory.boundary((), (), (), 0., 1.),
        )
        for step in leading_steps:
            metric(step)

        # TODO(kbanoop): Add optional next_step_type arg to trajectory.first. Or
        # implement trajectory.first_last().
        single_step_episode = trajectory.Trajectory(
            ts.StepType.FIRST, (), (), (), ts.StepType.LAST, -6., 1.)
        metric(single_step_episode)

        self.assertEqual(expected_result, metric.result())
Esempio n. 8
0
def train_eval(
        root_dir,
        dataset_path,
        env_name,
        # Training params
        tpu=False,
        use_gpu=False,
        num_gradient_updates=1000000,
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256, 256),
        # Agent params
        batch_size=256,
        bc_steps=0,
        actor_learning_rate=3e-5,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        reward_scale_factor=1.0,
        cql_alpha_learning_rate=3e-4,
        cql_alpha=5.0,
        cql_tau=10.0,
        num_cql_samples=10,
        reward_noise_variance=0.0,
        include_critic_entropy_term=False,
        use_lagrange_cql_alpha=True,
        log_cql_alpha_clipping=None,
        softmax_temperature=1.0,
        # Data params
        reward_shift=0.0,
        action_clipping=None,
        use_trajectories=False,
        data_shuffle_buffer_size_per_record=1,
        data_shuffle_buffer_size=100,
        data_num_shards=1,
        data_block_length=10,
        data_parallel_reads=None,
        data_parallel_calls=10,
        data_prefetch=10,
        data_cycle_length=10,
        # Others
        policy_save_interval=10000,
        eval_interval=10000,
        summary_interval=1000,
        learner_iterations_per_call=1,
        eval_episodes=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        seed=None):
    """Trains and evaluates CQL-SAC.

    Loads the D4RL environment `env_name` and reads a TFRecord dataset from
    `dataset_path`. Trains a `CqlSacAgent` on that static dataset for
    `num_gradient_updates` learner iterations. Periodically evaluates a greedy
    version of the agent's policy on the environment.

    Args:
      root_dir: Root directory for the learner, saved policies and eval
        summaries (written under `root_dir/eval`).
      dataset_path: Either a path ending in '.tfrecord' (glob patterns
        allowed), or a directory that contains
        `<env_name>/<env_name>*.tfrecord` files.
      env_name: Name of the environment passed to `load_d4rl`.
      tpu: Forwarded to `strategy_utils.get_strategy`.
      use_gpu: Forwarded to `strategy_utils.get_strategy`.
      num_gradient_updates: Total number of learner iterations to run.
      actor_fc_layers: Fully connected layer sizes of the actor network.
      critic_joint_fc_layers: Joint fully connected layer sizes of the critic.
      batch_size: Batch size of the training dataset.
      bc_steps: Forwarded to the agent as `num_bc_steps`.
      actor_learning_rate: Adam learning rate for the actor.
      critic_learning_rate: Adam learning rate for the critic.
      alpha_learning_rate: Adam learning rate for the entropy alpha.
      reward_scale_factor: Forwarded to the agent.
      cql_alpha_learning_rate: Forwarded to the agent.
      cql_alpha: Forwarded to the agent.
      cql_tau: Forwarded to the agent.
      num_cql_samples: Forwarded to the agent.
      reward_noise_variance: Forwarded to the agent.
      include_critic_entropy_term: Forwarded to the agent.
      use_lagrange_cql_alpha: Forwarded to the agent.
      log_cql_alpha_clipping: Forwarded to the agent.
      softmax_temperature: Forwarded to the agent.
      reward_shift: Forwarded to `create_tf_record_dataset`.
      action_clipping: Forwarded to `create_tf_record_dataset`.
      use_trajectories: Forwarded to `create_tf_record_dataset`.
      data_shuffle_buffer_size_per_record: Forwarded to
        `create_tf_record_dataset`. The other `data_*` arguments are forwarded
        the same way.
      data_shuffle_buffer_size: See above.
      data_num_shards: See above.
      data_block_length: See above.
      data_parallel_reads: See above.
      data_parallel_calls: See above.
      data_prefetch: See above.
      data_cycle_length: See above.
      policy_save_interval: Interval of the `PolicySavedModelTrigger`.
      eval_interval: Evaluation runs each time the train step reaches or
        crosses a multiple of this value. A falsy value disables evaluation.
      summary_interval: Forwarded to the `Learner`.
      learner_iterations_per_call: Maximum number of iterations per
        `Learner.run` call. Must be >= 1.
      eval_episodes: Number of episodes per evaluation run.
      debug_summaries: Forwarded to the agent.
      summarize_grads_and_vars: Forwarded to the agent.
      seed: Seed for TensorFlow, NumPy and the agent's `random_seed`.

    Raises:
      ValueError: If `learner_iterations_per_call` < 1, or if no files match
        `dataset_path`.
    """
    logging.info('Training CQL-SAC on: %s', env_name)
    if learner_iterations_per_call < 1:
        # Previously this caused a ZeroDivisionError (0) or a silent no-op loop.
        raise ValueError('learner_iterations_per_call must be >= 1, got %r' %
                         (learner_iterations_per_call,))
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # Load environment.
    env = load_d4rl(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    strategy = strategy_utils.get_strategy(tpu, use_gpu)

    if not dataset_path.endswith('.tfrecord'):
        dataset_path = os.path.join(dataset_path, env_name,
                                    '%s*.tfrecord' % env_name)
    logging.info('Loading dataset from %s', dataset_path)
    dataset_paths = tf.io.gfile.glob(dataset_path)
    if not dataset_paths:
        # Fail early with a clear message instead of building an empty dataset.
        raise ValueError('No dataset files match %s' % dataset_path)

    # Create dataset.
    with strategy.scope():
        dataset = create_tf_record_dataset(
            dataset_paths,
            batch_size,
            shuffle_buffer_size_per_record=data_shuffle_buffer_size_per_record,
            shuffle_buffer_size=data_shuffle_buffer_size,
            num_shards=data_num_shards,
            cycle_length=data_cycle_length,
            block_length=data_block_length,
            num_parallel_reads=data_parallel_reads,
            num_parallel_calls=data_parallel_calls,
            num_prefetch=data_prefetch,
            strategy=strategy,
            reward_shift=reward_shift,
            action_clipping=action_clipping,
            use_trajectories=use_trajectories)

    # Create agent.
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()
    with strategy.scope():
        train_step = train_utils.create_train_step()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        agent = cql_sac_agent.CqlSacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.keras.optimizers.Adam(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.keras.optimizers.Adam(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.keras.optimizers.Adam(
                learning_rate=alpha_learning_rate),
            cql_alpha=cql_alpha,
            num_cql_samples=num_cql_samples,
            include_critic_entropy_term=include_critic_entropy_term,
            use_lagrange_cql_alpha=use_lagrange_cql_alpha,
            cql_alpha_learning_rate=cql_alpha_learning_rate,
            target_update_tau=5e-3,
            target_update_period=1,
            random_seed=seed,
            cql_tau=cql_tau,
            reward_noise_variance=reward_noise_variance,
            num_bc_steps=bc_steps,
            td_errors_loss_fn=tf.math.squared_difference,
            gamma=0.99,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=None,
            log_cql_alpha_clipping=log_cql_alpha_clipping,
            softmax_temperature=softmax_temperature,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step)
        agent.initialize()

    # Create learner.
    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    # There is no real data collection. This metric is advanced by hand below
    # so that saved policies still carry an env-step count in their metadata.
    collect_env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(saved_model_dir,
                                         agent,
                                         train_step,
                                         interval=policy_save_interval,
                                         metadata_metrics={
                                             triggers.ENV_STEP_METADATA_KEY:
                                             collect_env_step_metric
                                         }),
        triggers.StepPerSecondLogTrigger(train_step, interval=100)
    ]
    cql_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn=lambda: dataset,
                                  triggers=learning_triggers,
                                  summary_interval=summary_interval,
                                  strategy=strategy)

    # Create actor for evaluation.
    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)
    eval_actor = actor.Actor(env,
                             eval_greedy_policy,
                             train_step,
                             metrics=actor.eval_metrics(eval_episodes),
                             summary_dir=os.path.join(root_dir, 'eval'),
                             episodes_per_run=eval_episodes)

    # Run.
    dummy_trajectory = trajectory.mid((), (), (), 0., 1.)
    total_updates = int(num_gradient_updates)
    updates_done = 0
    while updates_done < total_updates:
        # The last call may be shorter, so exactly `total_updates` learner
        # iterations run even when that number is not a multiple of
        # learner_iterations_per_call. The old version dropped the remainder.
        iterations = min(learner_iterations_per_call,
                         total_updates - updates_done)

        # Mimic collecting environment steps since we loaded a static dataset.
        for _ in range(iterations):
            collect_env_step_metric(dummy_trajectory)

        step_before = int(train_step.numpy())
        cql_learner.run(iterations=iterations)
        step_after = int(train_step.numpy())
        updates_done += iterations

        # Evaluate whenever a multiple of eval_interval was reached or crossed
        # during this call. A plain `step % eval_interval == 0` check after the
        # call skips evaluations when learner_iterations_per_call does not
        # divide eval_interval. This assumes train_step advances during
        # Learner.run (TODO confirm against the Learner implementation).
        if eval_interval and (step_after // eval_interval >
                              step_before // eval_interval):
            eval_actor.run_and_log()