def train_eval(
    root_dir,
    # Dataset params
    env_name,
    data_dir=None,
    load_pretrained=False,
    pretrained_model_dir=None,
    img_pad=4,
    frame_shape=(84, 84, 3),
    frame_stack=3,
    num_augmentations=2,  # K and M in DrQ
    # Training params
    contrastive_loss_weight=1.0,
    contrastive_loss_temperature=0.5,
    image_encoder_representation=True,
    initial_collect_steps=1000,
    num_train_steps=3000000,
    actor_fc_layers=(1024, 1024),
    critic_joint_fc_layers=(1024, 1024),
    # Agent params
    batch_size=256,
    actor_learning_rate=1e-3,
    critic_learning_rate=1e-3,
    alpha_learning_rate=1e-3,
    encoder_learning_rate=1e-3,
    actor_update_freq=2,
    gamma=0.99,
    target_update_tau=0.01,
    target_update_period=2,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    checkpoint_interval=10000,
    policy_save_interval=5000,
    eval_interval=10000,
    summary_interval=250,
    debug_summaries=False,
    eval_episodes_per_run=10,
    summarize_grads_and_vars=False):
  """Trains and evaluates SAC."""
  collect_env = env_utils.load_dm_env_for_training(
      env_name, frame_shape, frame_stack=frame_stack)
  eval_env = env_utils.load_dm_env_for_eval(
      env_name, frame_shape, frame_stack=frame_stack)

  logging.info('Data directory: %s', data_dir)
  logging.info('Num train steps: %d', num_train_steps)
  logging.info('Contrastive loss coeff: %.2f', contrastive_loss_weight)
  logging.info('Contrastive loss temperature: %.4f',
               contrastive_loss_temperature)
  logging.info('load_pretrained: %s', 'yes' if load_pretrained else 'no')
  logging.info('encoder representation: %s',
               'yes' if image_encoder_representation else 'no')

  load_episode_data = (contrastive_loss_weight > 0)
  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()
  image_encoder = networks.ImageEncoder(observation_tensor_spec)

  actor_net = model_utils.Actor(
      observation_tensor_spec,
      action_tensor_spec,
      image_encoder=image_encoder,
      fc_layers=actor_fc_layers,
      image_encoder_representation=image_encoder_representation)

  critic_net = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=image_encoder,
      joint_fc_layers=critic_joint_fc_layers)
  critic_net_2 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=image_encoder,
      joint_fc_layers=critic_joint_fc_layers)

  target_image_encoder = networks.ImageEncoder(observation_tensor_spec)
  target_critic_net_1 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=target_image_encoder)
  target_critic_net_2 = networks.Critic(
      (observation_tensor_spec, action_tensor_spec),
      image_encoder=target_image_encoder)

  agent = pse_drq_agent.DrQSacModifiedAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      actor_network=actor_net,
      critic_network=critic_net,
      critic_network_2=critic_net_2,
      target_critic_network=target_critic_net_1,
      target_critic_network_2=target_critic_net_2,
      actor_update_frequency=actor_update_freq,
      actor_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=actor_learning_rate),
      critic_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=critic_learning_rate),
      alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=alpha_learning_rate),
      contrastive_optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=encoder_learning_rate),
      contrastive_loss_weight=contrastive_loss_weight,
      contrastive_loss_temperature=contrastive_loss_temperature,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=tf.math.squared_difference,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      use_log_alpha_in_alpha_loss=False,
      gradient_clipping=None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step,
      num_augmentations=num_augmentations)
  agent.initialize()

  # Setup the replay buffer.
  reverb_replay, rb_observer = (
      replay_buffer_utils.get_reverb_buffer_and_observer(
          agent.collect_data_spec,
          sequence_length=2,
          replay_capacity=replay_capacity,
          port=reverb_port))

  # pylint: disable=g-long-lambda
  if num_augmentations == 0:
    image_aug = lambda traj, meta: (dict(
        experience=traj, augmented_obs=[], augmented_next_obs=[]), meta)
  else:
    image_aug = lambda traj, meta: pse_drq_agent.image_aug(
        traj, meta, img_pad, num_augmentations)
  augmented_dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).unbatch().map(
          image_aug, num_parallel_calls=3)
  augmented_iterator = iter(augmented_dataset)

  trajs = augmented_dataset.batch(batch_size).prefetch(50)
  if load_episode_data:
    # Load full episodes and zip them with the augmented trajectories.
    episodes = dataset_utils.load_episodes(
        os.path.join(data_dir, 'episodes2'), img_pad)
    episode_iterator = iter(episodes)
    dataset = tf.data.Dataset.zip((trajs, episodes)).prefetch(10)
  else:
    dataset = trajs
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir, agent, train_step, interval=policy_save_interval),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  agent_learner = model_utils.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn=experience_dataset_fn,
      triggers=learning_triggers,
      checkpoint_interval=checkpoint_interval,
      summary_interval=summary_interval,
      load_episode_data=load_episode_data,
      use_kwargs_in_agent_train=True,
      # Turn off the initialization of the optimizer variables since the agent
      # expects different batching for the `training_data_spec` and
      # `train_argspec`, which can't be handled in general by the
      # initialization logic in the learner.
      run_optimizer_variable_init=False)

  # If we haven't trained yet, make sure we collect some random samples first
  # to fill up the replay buffer with some experience.
  train_dir = os.path.join(root_dir, learner.TRAIN_DIR)

  # Code for loading the pretrained policy.
  if load_pretrained:
    # Note that num_train_steps is the same as the max_train_step we want to
    # load the pretrained policy for in our experiments.
    pretrained_policy = model_utils.load_pretrained_policy(
        pretrained_model_dir, num_train_steps)
    initial_collect_policy = pretrained_policy

    agent.policy.update_partial(pretrained_policy)
    agent.collect_policy.update_partial(pretrained_policy)
    logging.info('Restored pretrained policy.')
  else:
    initial_collect_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())

  initial_collect_actor = actor.Actor(
      collect_env,
      initial_collect_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10),
      summary_dir=train_dir,
      summary_interval=summary_interval,
      name='CollectActor')

  # If restarting with train_step > 0, the replay buffer will be empty
  # except for random experience. Populate the buffer with some on-policy
  # experience.
  if load_pretrained or (agent_learner.train_step_numpy > 0):
    for _ in range(batch_size * 50):
      collect_actor.run()

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes_per_run,
      metrics=actor.eval_metrics(buffer_size=10),
      summary_dir=os.path.join(root_dir, 'eval'),
      summary_interval=-1,
      name='EvalTrainActor')

  if eval_interval:
    logging.info('Evaluating.')
    img_summary(
        next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
    if load_episode_data:
      contrastive_img_summary(
          next(episode_iterator), agent, eval_actor.summary_writer, train_step)
    eval_actor.run_and_log()

  logging.info('Saving operative gin config file.')
  gin_path = os.path.join(train_dir, 'train_operative_gin_config.txt')
  with tf.io.gfile.GFile(gin_path, mode='w') as f:
    f.write(gin.operative_config_str())

  logging.info('Training starting at: %r', train_step.numpy())
  while train_step < num_train_steps:
    collect_actor.run()
    agent_learner.run(iterations=1)

    if (not eval_interval) and (train_step % 10000 == 0):
      img_summary(
          next(augmented_iterator)[0], agent_learner.train_summary_writer,
          train_step)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      img_summary(
          next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
      if load_episode_data:
        contrastive_img_summary(
            next(episode_iterator), agent, eval_actor.summary_writer,
            train_step)
      eval_actor.run_and_log()
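# Illustrative usage sketch (not part of the training pipeline): it shows one
# way train_eval might be invoked directly. The directory paths and environment
# name below are hypothetical placeholders; the real values are supplied via
# gin bindings / flags by this repository's launch scripts.
def example_train_eval_launch():
  train_eval(
      root_dir='/tmp/pse_drq_example',        # hypothetical output directory
      env_name='cartpole_swingup',            # hypothetical DM Control task
      data_dir='/tmp/pse_drq_example/data',   # hypothetical episode data dir
      contrastive_loss_weight=1.0,            # > 0, so episode data is loaded
      contrastive_loss_temperature=0.5,
      num_train_steps=100000,
      eval_interval=10000)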
def test_actor_updated_on_second_train(self):
  self.setup_agent()
  experience_spec = self._agent.collect_data_spec

  def _bound_specs(s):
    if s.dtype != tf.float32:
      return s
    return tensor_spec.BoundedTensorSpec(
        dtype=s.dtype, shape=s.shape, minimum=-1, maximum=1)

  experience_spec = tf.nest.map_structure(_bound_specs, experience_spec)
  sample_experience_1 = tensor_spec.sample_spec_nest(
      experience_spec, outer_dims=(2,))
  sample_experience_2 = tensor_spec.sample_spec_nest(
      experience_spec, outer_dims=(2,))
  augmented_sample_1 = pse_drq_agent.image_aug(
      sample_experience_1, (), img_pad=4, num_augmentations=2)
  augmented_sample_2 = pse_drq_agent.image_aug(
      sample_experience_2, (), img_pad=4, num_augmentations=2)

  augmented_sample = tf.nest.map_structure(
      # pylint: disable=g-long-lambda
      lambda t1, t2: tf.concat(
          [tf.expand_dims(t1, 0), tf.expand_dims(t2, 0)], axis=0),
      augmented_sample_1,
      augmented_sample_2)[0]
  augmented_experience = augmented_sample.pop('experience')
  sample_train_kwargs = augmented_sample

  self._agent.initialize()
  encoder_variables = self._image_encoder.variables
  num_encoder_variables = len(encoder_variables)
  # Evaluate here to get a copy of the values.
  actor_variables = self.evaluate([v for v in self._actor_net.variables])

  self._agent.train(augmented_experience, **sample_train_kwargs)
  updated_actor_variables = self.evaluate(
      [v for v in self._actor_net.variables])
  for v1, v2 in zip(actor_variables[num_encoder_variables:],
                    updated_actor_variables[num_encoder_variables:]):
    np.testing.assert_equal(v1, v2)

  # Second call; now the variables should differ.
  self._agent.train(augmented_experience, **sample_train_kwargs)
  updated_actor_variables = self.evaluate(
      [v for v in self._actor_net.variables])
  with self.assertRaises(AssertionError):
    for v1, v2 in zip(actor_variables[num_encoder_variables:],
                      updated_actor_variables[num_encoder_variables:]):
      np.testing.assert_equal(v1, v2)
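# Minimal sketch of the gating behavior the test above relies on. This is an
# assumption about how actor_update_frequency is typically applied in DrQ-style
# agents, not the agent's actual implementation: with actor_update_frequency=2,
# the actor is only updated on every second call to train(), so the first call
# leaves the actor variables unchanged and the second call modifies them.
def _should_update_actor(train_call_count, actor_update_frequency=2):
  """Returns True when this train() call should also update the actor."""
  return train_call_count % actor_update_frequency == 0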