Example #1
def train_eval(
        root_dir,
        # Dataset params
        env_name,
        data_dir=None,
        load_pretrained=False,
        pretrained_model_dir=None,
        img_pad=4,
        frame_shape=(84, 84, 3),
        frame_stack=3,
        num_augmentations=2,  # K and M in DrQ
        # Training params
        contrastive_loss_weight=1.0,
        contrastive_loss_temperature=0.5,
        image_encoder_representation=True,
        initial_collect_steps=1000,
        num_train_steps=3000000,
        actor_fc_layers=(1024, 1024),
        critic_joint_fc_layers=(1024, 1024),
        # Agent params
        batch_size=256,
        actor_learning_rate=1e-3,
        critic_learning_rate=1e-3,
        alpha_learning_rate=1e-3,
        encoder_learning_rate=1e-3,
        actor_update_freq=2,
        gamma=0.99,
        target_update_tau=0.01,
        target_update_period=2,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        checkpoint_interval=10000,
        policy_save_interval=5000,
        eval_interval=10000,
        summary_interval=250,
        debug_summaries=False,
        eval_episodes_per_run=10,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    collect_env = env_utils.load_dm_env_for_training(env_name,
                                                     frame_shape,
                                                     frame_stack=frame_stack)
    eval_env = env_utils.load_dm_env_for_eval(env_name,
                                              frame_shape,
                                              frame_stack=frame_stack)

    logging.info('Data directory: %s', data_dir)
    logging.info('Num train steps: %d', num_train_steps)
    logging.info('Contrastive loss coeff: %.2f', contrastive_loss_weight)
    logging.info('Contrastive loss temperature: %.4f',
                 contrastive_loss_temperature)
    logging.info('load_pretrained: %s', 'yes' if load_pretrained else 'no')
    logging.info('encoder representation: %s',
                 'yes' if image_encoder_representation else 'no')

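    # Episode data from disk is only needed when the contrastive (PSE) loss is
    # active.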
    load_episode_data = (contrastive_loss_weight > 0)
    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()
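    # Convolutional image encoder shared by the actor and both online critics.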
    image_encoder = networks.ImageEncoder(observation_tensor_spec)

    actor_net = model_utils.Actor(
        observation_tensor_spec,
        action_tensor_spec,
        image_encoder=image_encoder,
        fc_layers=actor_fc_layers,
        image_encoder_representation=image_encoder_representation)

    critic_net = networks.Critic((observation_tensor_spec, action_tensor_spec),
                                 image_encoder=image_encoder,
                                 joint_fc_layers=critic_joint_fc_layers)
    critic_net_2 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=image_encoder,
        joint_fc_layers=critic_joint_fc_layers)

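    # The target critics share a separate encoder copy that is kept in sync via
    # the agent's soft target updates.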
    target_image_encoder = networks.ImageEncoder(observation_tensor_spec)
    target_critic_net_1 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=target_image_encoder)
    target_critic_net_2 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=target_image_encoder)

    agent = pse_drq_agent.DrQSacModifiedAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        critic_network_2=critic_net_2,
        target_critic_network=target_critic_net_1,
        target_critic_network_2=target_critic_net_2,
        actor_update_frequency=actor_update_freq,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        contrastive_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=encoder_learning_rate),
        contrastive_loss_weight=contrastive_loss_weight,
        contrastive_loss_temperature=contrastive_loss_temperature,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        use_log_alpha_in_alpha_loss=False,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step,
        num_augmentations=num_augmentations)
    agent.initialize()

    # Setup the replay buffer.
    reverb_replay, rb_observer = (
        replay_buffer_utils.get_reverb_buffer_and_observer(
            agent.collect_data_spec,
            sequence_length=2,
            replay_capacity=replay_capacity,
            port=reverb_port))

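    # Random-shift image augmentation as in DrQ; when num_augmentations == 0,
    # the map is a pass-through that only adds empty augmentation fields.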
    # pylint: disable=g-long-lambda
    if num_augmentations == 0:
        image_aug = lambda traj, meta: (dict(
            experience=traj, augmented_obs=[], augmented_next_obs=[]), meta)
    else:
        image_aug = lambda traj, meta: pse_drq_agent.image_aug(
            traj, meta, img_pad, num_augmentations)
    augmented_dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                                 num_steps=2).unbatch().map(
                                                     image_aug,
                                                     num_parallel_calls=3)
    augmented_iterator = iter(augmented_dataset)

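    # Batch the augmented transitions and, when the contrastive loss is active,
    # zip them with full episodes loaded from data_dir.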
    trajs = augmented_dataset.batch(batch_size).prefetch(50)
    if load_episode_data:
        # Load full episodes and zip them
        episodes = dataset_utils.load_episodes(
            os.path.join(data_dir, 'episodes2'), img_pad)
        episode_iterator = iter(episodes)
        dataset = tf.data.Dataset.zip((trajs, episodes)).prefetch(10)
    else:
        dataset = trajs
    experience_dataset_fn = lambda: dataset

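    # Triggers periodically export the policy as a SavedModel and log the
    # training steps per second.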
    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    learning_triggers = [
        triggers.PolicySavedModelTrigger(saved_model_dir,
                                         agent,
                                         train_step,
                                         interval=policy_save_interval),
        triggers.StepPerSecondLogTrigger(train_step,
                                         interval=summary_interval),
    ]

    agent_learner = model_utils.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn=experience_dataset_fn,
        triggers=learning_triggers,
        checkpoint_interval=checkpoint_interval,
        summary_interval=summary_interval,
        load_episode_data=load_episode_data,
        use_kwargs_in_agent_train=True,
        # Turn off the initialization of the optimizer variables since the agent
        # expects different batching for the `training_data_spec` and
        # `train_argspec`, which can't be handled in general by the
        # initialization logic in the learner.
        run_optimizer_variable_init=False)

    # If we haven't trained yet, make sure we first collect some random samples
    # to fill up the replay buffer with experience.
    train_dir = os.path.join(root_dir, learner.TRAIN_DIR)

    # Code for loading pretrained policy.
    if load_pretrained:
        # Note that num_train_steps is the same as the max_train_step at which
        # we want to load the pretrained policy for our experiments.
        pretrained_policy = model_utils.load_pretrained_policy(
            pretrained_model_dir, num_train_steps)
        initial_collect_policy = pretrained_policy

        agent.policy.update_partial(pretrained_policy)
        agent.collect_policy.update_partial(pretrained_policy)
        logging.info('Restored pretrained policy.')
    else:
        initial_collect_policy = random_py_policy.RandomPyPolicy(
            collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        initial_collect_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

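    # Wrap the agent's TF collect policy so it can drive the Python environment
    # eagerly.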
    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                observers=[rb_observer],
                                metrics=actor.collect_metrics(buffer_size=10),
                                summary_dir=train_dir,
                                summary_interval=summary_interval,
                                name='CollectActor')

    # If restarting with train_step > 0, the replay buffer will be empty
    # except for random experience. Populate the buffer with some on-policy
    # experience.
    if load_pretrained or (agent_learner.train_step_numpy > 0):
        for _ in range(batch_size * 50):
            collect_actor.run()

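    # The agent's greedy policy is used for evaluation.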
    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(eval_env,
                             greedy_policy,
                             train_step,
                             episodes_per_run=eval_episodes_per_run,
                             metrics=actor.eval_metrics(buffer_size=10),
                             summary_dir=os.path.join(root_dir, 'eval'),
                             summary_interval=-1,
                             name='EvalTrainActor')

    if eval_interval:
        logging.info('Evaluating.')
        img_summary(
            next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
        if load_episode_data:
            contrastive_img_summary(next(episode_iterator), agent,
                                    eval_actor.summary_writer, train_step)
        eval_actor.run_and_log()

    logging.info('Saving operative gin config file.')
    gin_path = os.path.join(train_dir, 'train_operative_gin_config.txt')
    with tf.io.gfile.GFile(gin_path, mode='w') as f:
        f.write(gin.operative_config_str())

    logging.info('Training starting at: %r', train_step.numpy())
    while train_step < num_train_steps:
        collect_actor.run()
        agent_learner.run(iterations=1)
        if (not eval_interval) and (train_step % 10000 == 0):
            img_summary(
                next(augmented_iterator)[0],
                agent_learner.train_summary_writer, train_step)
        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            img_summary(
                next(augmented_iterator)[0], eval_actor.summary_writer,
                train_step)
            if load_episode_data:
                contrastive_img_summary(next(episode_iterator), agent,
                                        eval_actor.summary_writer, train_step)
            eval_actor.run_and_log()
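
The sketch below shows one hypothetical way to invoke train_eval from a standalone script. The flag names, default values, and the choice of disabling the contrastive loss are illustrative assumptions and do not come from the original repository.

# Hypothetical driver for train_eval; flag names and defaults are assumptions.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('root_dir', '/tmp/pse_drq',
                    'Directory for checkpoints and summaries.')
flags.DEFINE_string('env_name', 'cartpole-swingup',
                    'DeepMind Control Suite task name.')


def main(_):
    # Disable the contrastive (PSE) loss so no episode data directory is needed.
    train_eval(FLAGS.root_dir,
               env_name=FLAGS.env_name,
               contrastive_loss_weight=0.0,
               num_train_steps=100000)


if __name__ == '__main__':
    app.run(main)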
Example #2
    def test_actor_updated_on_second_train(self):
        self.setup_agent()
        experience_spec = self._agent.collect_data_spec

        def _bound_specs(s):
            if s.dtype != tf.float32:
                return s
            return tensor_spec.BoundedTensorSpec(dtype=s.dtype,
                                                 shape=s.shape,
                                                 minimum=-1,
                                                 maximum=1)

        experience_spec = tf.nest.map_structure(_bound_specs, experience_spec)

        sample_experience_1 = tensor_spec.sample_spec_nest(experience_spec,
                                                           outer_dims=(2, ))
        sample_experience_2 = tensor_spec.sample_spec_nest(experience_spec,
                                                           outer_dims=(2, ))

        augmented_sample_1 = pse_drq_agent.image_aug(sample_experience_1, (),
                                                     img_pad=4,
                                                     num_augmentations=2)

        augmented_sample_2 = pse_drq_agent.image_aug(sample_experience_2, (),
                                                     img_pad=4,
                                                     num_augmentations=2)

        augmented_sample = tf.nest.map_structure(
            # pylint: disable=g-long-lambda
            lambda t1, t2: tf.concat(
                [tf.expand_dims(t1, 0),
                 tf.expand_dims(t2, 0)], axis=0),
            augmented_sample_1,
            augmented_sample_2)[0]

        augmented_experience = augmented_sample.pop('experience')
        sample_train_kwargs = augmented_sample

        self._agent.initialize()

        encoder_variables = self._image_encoder.variables
        num_encoder_variables = len(encoder_variables)

        # Evaluate here to get a copy of the values.
        actor_variables = self.evaluate([v for v in self._actor_net.variables])

        self._agent.train(augmented_experience, **sample_train_kwargs)

        updated_actor_variables = self.evaluate(
            [v for v in self._actor_net.variables])

        for v1, v2 in zip(actor_variables[num_encoder_variables:],
                          updated_actor_variables[num_encoder_variables:]):
            np.testing.assert_equal(v1, v2)

        # After the second call, the variables should differ.
        self._agent.train(augmented_experience, **sample_train_kwargs)

        updated_actor_variables = self.evaluate(
            [v for v in self._actor_net.variables])

        with self.assertRaises(AssertionError):
            for v1, v2 in zip(actor_variables[num_encoder_variables:],
                              updated_actor_variables[num_encoder_variables:]):
                np.testing.assert_equal(v1, v2)
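
In this test, the first train() call is expected to leave the actor head unchanged because the agent is presumably built with actor_update_frequency=2 in setup_agent() (matching Example #1), so the actor update only fires on every second training step. The second call then updates the actor, which is why the equality checks are wrapped in assertRaises(AssertionError).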