def update():
  """Builds the soft update for both target critic networks.

  Returns:
    A `tf.group` over the two `common.soft_variables_update` results.
  """
  update_ops = []
  update_ops.append(
      common.soft_variables_update(
          self._critic_network_1.variables,
          self._target_critic_network_1.variables,
          tau=tau,
          tau_non_trainable=1.0))

  # The second critic pair is filtered through `deduped_network_variables`
  # against the first pair (presumably so that variables shared between the
  # two critics are not updated twice -- confirm against `common`).
  source_vars_2 = common.deduped_network_variables(
      self._critic_network_2, self._critic_network_1)
  target_vars_2 = common.deduped_network_variables(
      self._target_critic_network_2, self._target_critic_network_1)
  update_ops.append(
      common.soft_variables_update(
          source_vars_2, target_vars_2, tau=tau, tau_non_trainable=1.0))

  return tf.group(*update_ops)
def testUpdateOnlyTargetVariables(self, tau):
  """Soft update must rewrite the targets and leave the sources untouched."""
  inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  # One fully connected layer per scope; only their variables are used below.
  for scope_name in ('source', 'target'):
    tf.contrib.layers.fully_connected(inputs, 2, scope=scope_name)
  source_vars = tf.contrib.framework.get_model_variables('source')
  target_vars = tf.contrib.framework.get_model_variables('target')
  update_op = common.soft_variables_update(source_vars, target_vars, tau=tau)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  old_source, old_target = self.evaluate([source_vars, target_vars])
  self.evaluate(update_op)
  new_source, new_target = self.evaluate([source_vars, target_vars])

  snapshots = zip(old_source, old_target, new_source, new_target)
  for src_before, tgt_before, src_after, tgt_after in snapshots:
    expected_target = tau * src_before + (1 - tau) * tgt_before
    # Source variables don't change
    self.assertAllClose(src_after, src_before)
    # Target variables are updated
    self.assertAllClose(tgt_after, expected_target)
def _initialize(self):
  """Synchronizes every target critic with its source critic.

  Each synchronization is a `common.soft_variables_update` call with
  `tau=1.0`. The two no-entropy critic pairs are only synchronized when
  `self._critic_network_no_entropy_1` is not None. Nothing is returned.
  """

  def _hard_copy(source_net, target_net):
    # tau=1.0: the target variables take the source values.
    common.soft_variables_update(
        source_net.variables, target_net.variables, tau=1.0)

  _hard_copy(self._critic_network_1, self._target_critic_network_1)
  _hard_copy(self._critic_network_2, self._target_critic_network_2)

  if self._critic_network_no_entropy_1 is None:
    return

  _hard_copy(self._critic_network_no_entropy_1,
             self._target_critic_network_no_entropy_1)
  _hard_copy(self._critic_network_no_entropy_2,
             self._target_critic_network_no_entropy_2)
def _initialize(self):
  """Copies the Q-network weights into the target network(s).

  The main target network is always synchronized. When `self._enable_td3`
  is truthy, both delayed target networks are synchronized from the same
  Q network as well. Every copy uses `tau=1.0`.
  """
  common.soft_variables_update(
      self._q_network.variables, self._target_q_network.variables, tau=1.0)

  if not self._enable_td3:
    return

  delayed_targets = (self._target_q_network_delayed,
                     self._target_q_network_delayed_2)
  for delayed_net in delayed_targets:
    common.soft_variables_update(
        self._q_network.variables, delayed_net.variables, tau=1.0)
def testUpdateOnlyTargetVariables(self, tau):
  """Graph-mode check that only the target variables are rewritten."""
  with tf.Graph().as_default() as graph:
    inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
    # One fully connected layer per scope; only their variables are used.
    for scope_name in ('source', 'target'):
      tf.contrib.layers.fully_connected(inputs, 2, scope=scope_name)
    source_vars = tf.contrib.framework.get_model_variables('source')
    target_vars = tf.contrib.framework.get_model_variables('target')
    update_op = common.soft_variables_update(source_vars, target_vars, tau=tau)

    with self.test_session(graph=graph) as sess:
      tf.global_variables_initializer().run()
      fetches = [source_vars, target_vars]
      old_source, old_target = sess.run(fetches)
      sess.run(update_op)
      new_source, new_target = sess.run(fetches)

      snapshots = zip(old_source, old_target, new_source, new_target)
      for src_before, tgt_before, src_after, tgt_after in snapshots:
        expected_target = tau*src_before + (1-tau)*tgt_before
        # Source variables don't change
        self.assertAllClose(src_after, src_before)
        # Target variables are updated
        self.assertAllClose(tgt_after, expected_target)
def _initialize(self):
  """Synchronizes the target networks with their source networks.

  Both critics and the actor are copied into their respective target
  networks via `common.soft_variables_update` with `tau=1.0`.
  """

  def _hard_copy(source_net, target_net):
    # tau=1.0: the target variables take the source values.
    common.soft_variables_update(
        source_net.variables, target_net.variables, tau=1.0)

  _hard_copy(self._critic_network_1, self._target_critic_network_1)
  _hard_copy(self._critic_network_2, self._target_critic_network_2)
  _hard_copy(self._actor_network, self._target_actor_network)
def update_target():
  """Soft-updates the target encoder and both target critics.

  The agent's `target_update_tau` is used as the rate for trainable and
  non-trainable variables alike.
  """
  tau = tf_agent.target_update_tau

  def _soft_sync(source_vars, target_vars):
    common.soft_variables_update(
        source_vars, target_vars, tau=tau, tau_non_trainable=tau)

  _soft_sync(e_enc.variables, e_enc_t.variables)
  _soft_sync(
      tf_agent._critic_network_1.variables,  # pylint: disable=protected-access
      tf_agent._target_critic_network_1.variables)  # pylint: disable=protected-access
  _soft_sync(
      tf_agent._critic_network_2.variables,  # pylint: disable=protected-access
      tf_agent._target_critic_network_2.variables)  # pylint: disable=protected-access
def update(self, policy, tau=1.0, sort_variables_by_name=False):
  """Pulls the variables of `policy` into this policy.

  Args:
    policy: The policy whose variables are the source of the update.
    tau: A float scalar in [0, 1]. With the default of 1.0 this is a hard
      update.
    sort_variables_by_name: A bool; when True the variables are sorted by
      name before the update is performed.

  Returns:
    A TF op performing the update, or `tf.no_op()` when this policy has no
    variables.
  """
  if not self.variables():
    return tf.no_op()

  return common.soft_variables_update(
      policy.variables(),
      self.variables(),
      tau=tau,
      sort_variables_by_name=sort_variables_by_name)
def update_partial(self, policy, tau=1.0):
  """Pulls the variables of `policy` into a prefix of this policy's variables.

  Only the first `len(policy.variables())` variables of this policy take
  part in the update; variables are sorted by name for the update.

  Args:
    policy: The policy whose variables are the source of the update.
    tau: A float scalar in [0, 1]. With the default of 1.0 this is a hard
      update. This is used for trainable variables.

  Returns:
    A TF op performing the update, or `tf.no_op()` when this policy has no
    variables.
  """
  if not self.variables():
    return tf.no_op()

  source_vars = policy.variables()
  target_vars = self.variables()[:len(source_vars)]
  return common.soft_variables_update(
      source_vars,
      target_vars,
      tau=tau,
      tau_non_trainable=None,
      sort_variables_by_name=True)
def testUpdateOnlyTargetVariables(self, tau):
  """Keras-layer check that only the target weights are rewritten."""
  inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
  source_net, target_net = [
      tf.keras.layers.Dense(2, name=layer_name)
      for layer_name in ('source_net', 'target_net')
  ]
  # Force variable creation
  for layer in (source_net, target_net):
    layer(inputs)
  source_vars = source_net.trainable_weights
  target_vars = target_net.trainable_weights

  self.evaluate(tf.compat.v1.global_variables_initializer())
  old_source, old_target = self.evaluate([source_vars, target_vars])

  # The update is built only after the "before" snapshot has been taken.
  update_op = common.soft_variables_update(source_vars, target_vars, tau=tau)
  self.evaluate(update_op)
  new_source, new_target = self.evaluate([source_vars, target_vars])

  snapshots = zip(old_source, old_target, new_source, new_target)
  for src_before, tgt_before, src_after, tgt_after in snapshots:
    expected_target = tau * src_before + (1 - tau) * tgt_before
    # Source variables don't change
    self.assertAllClose(src_after, src_before)
    # Target variables are updated
    self.assertAllClose(tgt_after, expected_target)
def __init__(self,
             action_spec,
             actor_network: Network,
             critic_network: Network,
             critic_loss=None,
             target_entropy=None,
             initial_log_alpha=0.0,
             target_update_tau=0.05,
             target_update_period=1,
             dqda_clipping=None,
             actor_optimizer=None,
             critic_optimizer=None,
             alpha_optimizer=None,
             gradient_clipping=None,
             train_step_counter=None,
             debug_summaries=False,
             name="SacAlgorithm"):
    """Create a SacAlgorithm.

    Two critics and two target critics are used: ``critic_network`` itself,
    a copy of it, and one copy of each of those as target networks. The
    target critics are synchronized with their critics at the end of
    construction.

    Args:
        action_spec (nested BoundedTensorSpec): representing the actions.
        actor_network (Network): The network will be called with
            call(observation, step_type).
        critic_network (Network): The network will be called with
            call(observation, action, step_type).
        critic_loss (None|OneStepTDLoss): an object for calculating critic
            loss. If None, a default OneStepTDLoss will be used.
        target_entropy (float|None): The target average policy entropy, for
            updating alpha. If None, the sum of
            ``dist_utils.calc_default_target_entropy`` over all flattened
            action specs is used.
        initial_log_alpha (float): initial value for variable log_alpha
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        dqda_clipping (float): when computing the actor loss, clips the
            gradient dqda element-wise between
            [-dqda_clipping, dqda_clipping]. Does not perform clipping if
            dqda_clipping == 0.
        actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
        critic_optimizer (tf.optimizers.Optimizer): The optimizer for
            critic.
        alpha_optimizer (tf.optimizers.Optimizer): The optimizer for alpha.
        gradient_clipping (float): Norm length to clip gradients.
        train_step_counter (tf.Variable): An optional counter to increment
            every time a new iteration is started. If None, it will use
            tf.summary.experimental.get_step(). If this is still None, a
            counter will be created.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    # The second critic is a copy of the supplied critic network.
    critic_network1 = critic_network
    critic_network2 = critic_network.copy(name='CriticNetwork2')
    # Trainable log(alpha) variable, optimized through `alpha_optimizer`'s
    # slot below.
    log_alpha = tfa_common.create_variable(
        name='log_alpha',
        initial_value=initial_log_alpha,
        dtype=tf.float32,
        trainable=True)
    # NOTE(review): `optimizer` and `get_trainable_variables_func` are both
    # three-element lists (actor, critics, alpha); presumably the base class
    # pairs them by position -- confirm against the base class.
    super().__init__(
        action_spec,
        train_state_spec=SacState(
            share=SacShareState(actor=actor_network.state_spec),
            actor=SacActorState(
                critic1=critic_network.state_spec,
                critic2=critic_network.state_spec),
            critic=SacCriticState(
                critic1=critic_network.state_spec,
                critic2=critic_network.state_spec,
                target_critic1=critic_network.state_spec,
                target_critic2=critic_network.state_spec)),
        action_distribution_spec=actor_network.output_spec,
        predict_state_spec=actor_network.state_spec,
        optimizer=[actor_optimizer, critic_optimizer, alpha_optimizer],
        get_trainable_variables_func=[
            lambda: actor_network.trainable_variables,
            lambda: (critic_network1.trainable_variables + critic_network2.
                     trainable_variables), lambda: [log_alpha]
        ],
        gradient_clipping=gradient_clipping,
        train_step_counter=train_step_counter,
        debug_summaries=debug_summaries,
        name=name)

    self._log_alpha = log_alpha
    self._actor_network = actor_network
    self._critic_network1 = critic_network1
    self._critic_network2 = critic_network2
    # Each target critic is a copy of its corresponding critic.
    self._target_critic_network1 = self._critic_network1.copy(
        name='TargetCriticNetwork1')
    self._target_critic_network2 = self._critic_network2.copy(
        name='TargetCriticNetwork2')
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._alpha_optimizer = alpha_optimizer

    if critic_loss is None:
        critic_loss = OneStepTDLoss(debug_summaries=debug_summaries)
    self._critic_loss = critic_loss

    flat_action_spec = tf.nest.flatten(self._action_spec)
    # Only the first flattened action spec is inspected here.
    self._is_continuous = tensor_spec.is_continuous(flat_action_spec[0])
    if target_entropy is None:
        # Default: sum of the per-spec default target entropies.
        target_entropy = np.sum(
            list(
                map(dist_utils.calc_default_target_entropy,
                    flat_action_spec)))
    self._target_entropy = target_entropy

    self._dqda_clipping = dqda_clipping

    # Updater for the target critics, configured with the soft-update factor
    # and period.
    self._update_target = common.get_target_updater(
        models=[self._critic_network1, self._critic_network2],
        target_models=[
            self._target_critic_network1, self._target_critic_network2
        ],
        tau=target_update_tau,
        period=target_update_period)

    # Initial synchronization (tau=1.0) of each target critic with its
    # critic.
    tfa_common.soft_variables_update(
        self._critic_network1.variables,
        self._target_critic_network1.variables,
        tau=1.0)
    tfa_common.soft_variables_update(
        self._critic_network2.variables,
        self._target_critic_network2.variables,
        tau=1.0)
def update_delayed():
  """Soft-updates the second delayed target network from the first.

  Returns:
    The result of `common.soft_variables_update`, using `tau_delayed` for
    trainable variables and 1.0 for non-trainable ones.
  """
  source_vars = self._target_q_network_delayed.variables
  target_vars = self._target_q_network_delayed_2.variables
  delayed_update = common.soft_variables_update(
      source_vars, target_vars, tau=tau_delayed, tau_non_trainable=1.0)
  return delayed_update
def after_train(self, training_info):
    """Move the prediction network towards the trained network.

    Does nothing when no prediction network is in use. Otherwise the
    variables of ``self._net`` are soft-updated into ``self._predict_net``
    with ``self._net_moving_average_rate`` as the update factor.

    Args:
        training_info: not used by this method.
    """
    if not self._predict_net:
        return
    tfa_common.soft_variables_update(
        self._net.variables,
        self._predict_net.variables,
        tau=self._net_moving_average_rate)
def _initialize_v1(self): self._q_network.create_variables() if self._target_q_network: self._target_q_network.create_variables() return common.soft_variables_update( self._q_network.variables, self._target_q_network.variables, tau=1.0)
def __init__(self,
             output_dim,
             noise_dim=32,
             input_tensor_spec=None,
             hidden_layers=(256, ),
             net: Network = None,
             net_moving_average_rate=None,
             entropy_regularization=0.,
             kernel_sharpness=2.,
             mi_weight=None,
             mi_estimator_cls=MIEstimator,
             optimizer: tf.optimizers.Optimizer = None,
             name="Generator"):
    """Create a Generator.

    Args:
        output_dim (int): dimension of output
        noise_dim (int): dimension of noise
        input_tensor_spec (nested TensorSpec): spec of inputs. If there is
            no inputs, this should be None.
        hidden_layers (tuple): size of hidden layers.
        net (Network): network for generating outputs from [noise, inputs]
            or noise (if inputs is None). If None, a default one with
            hidden_layers will be created
        net_moving_average_rate (float): If provided, use a moving average
            version of net to do prediction. This has been shown to be
            effective for GAN training (arXiv:1907.02544, arXiv:1812.04948).
        entropy_regularization (float): weight of entropy regularization
        kernel_sharpness (float): Used only for entropy_regularization > 0.
            We calculate the kernel in SVGD as:
                exp(-kernel_sharpness * reduce_mean((x-y)^2/width)),
            where width is the elementwise moving average of (x-y)^2
        mi_weight (float|None): if not None, a mutual information estimator
            is created with ``mi_estimator_cls`` and this value is stored as
            its weight. How the weight enters the loss is not visible in
            this constructor.
        mi_estimator_cls (type): the class of mutual information estimator
            for maximizing the mutual information between [noise, inputs]
            and [outputs, inputs].
        optimizer (tf.optimizers.Optimizer): optimizer (optional)
        name (str): name of this generator
    """
    super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
    self._noise_dim = noise_dim
    self._entropy_regularization = entropy_regularization
    # Select the gradient function: `_ml_grad` without entropy
    # regularization, `_stein_grad` (with its kernel-width averager and
    # sharpness) otherwise.
    if entropy_regularization == 0:
        self._grad_func = self._ml_grad
    else:
        self._grad_func = self._stein_grad
        self._kernel_width_averager = AdaptiveAverager(
            tensor_spec=tf.TensorSpec(shape=(output_dim, )))
        self._kernel_sharpness = kernel_sharpness

    noise_spec = tf.TensorSpec(shape=[noise_dim])

    if net is None:
        # Default network: takes the noise alone, or [noise, inputs] when an
        # input spec is given.
        net_input_spec = noise_spec
        if input_tensor_spec is not None:
            net_input_spec = [net_input_spec, input_tensor_spec]
        net = EncodingNetwork(
            name="Generator",
            input_tensor_spec=net_input_spec,
            fc_layer_params=hidden_layers,
            last_layer_size=output_dim)

    self._mi_estimator = None
    self._input_tensor_spec = input_tensor_spec
    if mi_weight is not None:
        x_spec = noise_spec
        y_spec = tf.TensorSpec((output_dim, ))
        if input_tensor_spec is not None:
            x_spec = [x_spec, input_tensor_spec]
        self._mi_estimator = mi_estimator_cls(
            x_spec, y_spec, sampler='shift')
        # NOTE(review): `_mi_weight` is only assigned on this branch; code
        # reading it presumably checks `_mi_estimator` first -- confirm.
        self._mi_weight = mi_weight
    self._net = net
    self._predict_net = None
    self._net_moving_average_rate = net_moving_average_rate
    if net_moving_average_rate:
        # NOTE(review): "Genrator_average" looks like a misspelling of
        # "Generator_average"; it is a runtime network name, so it is left
        # unchanged here.
        self._predict_net = net.copy(name="Genrator_average")
        # Start the prediction network as an exact copy (tau=1.0) of net.
        tfa_common.soft_variables_update(
            self._net.variables, self._predict_net.variables, tau=1.0)
def update():
  """Soft-updates the target network from `sc_net`.

  Returns:
    The result of `common.soft_variables_update` with factor `tau`.
  """
  return common.soft_variables_update(
      sc_net.variables, target_sc_net.variables, tau=tau)
def train_eval( root_dir, random_seed=None, # Dataset params domain_name='cartpole', task_name='swingup', frame_shape=(84, 84, 3), image_aug_type='random_shifting', # None/'random_shifting' frame_stack=3, action_repeat=4, # Params for learning num_env_steps=1000000, learn_ceb=True, use_augmented_q=False, # Params for CEB e_ctor=encoders.FRNConv, e_head_ctor=encoders.MVNormalDiagParamHead, b_ctor=encoders.FRNConv, b_head_ctor=encoders.MVNormalDiagParamHead, conv_feature_dim=50, # deterministic feature used by actor/critic/ceb ceb_feature_dim=50, ceb_action_condition=True, ceb_backward_encode_rewards=True, initial_feature_step=0, feature_lr=3e-4, feature_lr_schedule=None, ceb_beta=0.01, ceb_beta_schedule=None, ceb_generative_ratio=0.0, ceb_generative_items=None, feature_grad_clip=None, enc_ema_tau=0.05, # if enc_ema_tau=None, ceb also learns backend encoder use_critic_grad=True, # Params for SAC actor_kernel_init='glorot_uniform', normal_proj_net=sac_agent.sac_normal_projection_net, critic_kernel_init='glorot_uniform', critic_last_kernel_init='glorot_uniform', actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Params for collect collect_every=1, initial_collect_steps=1000, collect_steps_per_iteration=1, replay_buffer_capacity=100000, # Params for target update target_update_tau=0.005, target_update_period=1, # Params for train batch_size=256, actor_learning_rate=3e-4, actor_lr_schedule=None, critic_learning_rate=3e-4, critic_lr_schedule=None, alpha_learning_rate=3e-4, alpha_lr_schedule=None, td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error, gamma=0.99, reward_scale_factor=1.0, gradient_clipping=None, use_tf_functions=True, drivers_in_graph=True, # Params for eval num_eval_episodes=10, eval_env_interval=5000, # number of env steps greedy_eval_policy=True, train_next_frame_decoder=False, # Params for summaries and logging baseline_log_fn=None, checkpoint_env_interval=100000, # number of env 
steps log_env_interval=1000, # number of env steps summary_interval=1000, image_summary_interval=0, summaries_flush_secs=10, debug_summaries=False, summarize_grads_and_vars=False, eval_metrics_callback=None): """train and eval for PI-SAC.""" if random_seed is not None: tf.compat.v1.set_random_seed(random_seed) np.random.seed(random_seed) # Load baseline logs and write to tensorboard if baseline_log_fn is not None: baseline_log_fn(root_dir, domain_name, task_name, action_repeat) if root_dir is None: raise AttributeError('train_eval requires a root_dir.') # Set iterations and intervals to be computed relative to the number of # environment steps rather than the number of gradient steps. num_iterations = ( num_env_steps * collect_every // collect_steps_per_iteration + (initial_feature_step)) checkpoint_interval = (checkpoint_env_interval * collect_every // collect_steps_per_iteration) eval_interval = (eval_env_interval * collect_every // collect_steps_per_iteration) log_interval = (log_env_interval * collect_every // collect_steps_per_iteration) logging.info('num_env_steps = %d (env steps)', num_env_steps) logging.info('initial_feature_step = %d (gradient steps)', initial_feature_step) logging.info('num_iterations = %d (gradient steps)', num_iterations) logging.info('checkpoint interval (env steps) = %d', checkpoint_env_interval) logging.info('checkpoint interval (gradient steps) = %d', checkpoint_interval) logging.info('eval interval (env steps) = %d', eval_env_interval) logging.info('eval interval (gradient steps) = %d', eval_interval) logging.info('log interval (env steps) = %d', log_env_interval) logging.info('log interval (gradient steps) = %d', log_interval) root_dir = os.path.expanduser(root_dir) summary_writer = tf.compat.v2.summary.create_file_writer( root_dir, flush_millis=summaries_flush_secs * 1000) summary_writer.set_as_default() eval_histograms = [ pisac_metric_utils.ReturnHistogram(buffer_size=num_eval_episodes), ] eval_metrics = [ 
tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes), pisac_metric_utils.ReturnStddevMetric(buffer_size=num_eval_episodes), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes) ] # create training environment render_configs = { 'height': frame_shape[0], 'width': frame_shape[1], 'camera_id': dict(quadruped=2).get(domain_name, 0), } tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(domain_name, task_name, render_configs, frame_stack, action_repeat)) eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(domain_name, task_name, render_configs, frame_stack, action_repeat)) # Define global step g_step = common.create_variable('g_step') # Spec ims_shape = frame_shape[:2] + (frame_shape[2] * frame_stack, ) ims_spec = tf.TensorSpec(shape=ims_shape, dtype=tf.uint8) conv_feature_spec = tf.TensorSpec(shape=(conv_feature_dim, ), dtype=tf.float32) action_spec = tf_env.action_spec() # Forward encoder e_enc = e_ctor(ims_spec, output_dim=conv_feature_dim, name='e') e_enc_t = e_ctor(ims_spec, output_dim=conv_feature_dim, name='e_t') e_enc.create_variables() e_enc_t.create_variables() common.soft_variables_update(e_enc.variables, e_enc_t.variables, tau=1.0, tau_non_trainable=1.0) # Forward encoder head if e_head_ctor is None: e_head = None else: stacked_action_spec = tensor_spec.BoundedTensorSpec( action_spec.shape[:-1] + (action_spec.shape[-1] * frame_stack), action_spec.dtype, action_spec.minimum.tolist() * frame_stack, action_spec.maximum.tolist() * frame_stack, action_spec.name) e_head_spec = [conv_feature_spec, stacked_action_spec ] if ceb_action_condition else conv_feature_spec e_head = e_head_ctor(e_head_spec, output_dim=ceb_feature_dim, name='e_head') e_head.create_variables() # Backward encoder b_enc = b_ctor(ims_spec, output_dim=conv_feature_dim, name='b') b_enc.create_variables() # Backward encoder head if b_head_ctor is None: b_head = None else: stacked_reward_spec = tf.TensorSpec(shape=(frame_stack, ), dtype=tf.float32) b_head_spec = 
[conv_feature_spec, stacked_reward_spec ] if ceb_backward_encode_rewards else conv_feature_spec b_head = b_head_ctor(b_head_spec, output_dim=ceb_feature_dim, name='b_head') b_head.create_variables() # future decoder for generative formulation future_deconv = None future_reward_mlp = None y_decoders = None if ceb_generative_ratio > 0.0: future_deconv = utils.SimpleDeconv(conv_feature_spec, output_tensor_spec=ims_spec) future_deconv.create_variables() future_reward_mlp = utils.MLP(conv_feature_spec, hidden_dims=(ceb_feature_dim, ceb_feature_dim // 2, frame_stack)) future_reward_mlp.create_variables() y_decoders = [future_deconv, future_reward_mlp] m_vars = e_enc.trainable_variables if enc_ema_tau is None: m_vars += b_enc.trainable_variables else: # do not train b_enc common.soft_variables_update(e_enc.variables, b_enc.variables, tau=1.0, tau_non_trainable=1.0) if e_head_ctor is not None: m_vars += e_head.trainable_variables if b_head_ctor is not None: m_vars += b_head.trainable_variables if ceb_generative_ratio > 0.0: m_vars += future_deconv.trainable_variables m_vars += future_reward_mlp.trainable_variables feature_lr_fn = schedule_utils.get_schedule_fn(base=feature_lr, sched=feature_lr_schedule, step=g_step) m_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=feature_lr_fn) # CEB beta schedule, e.q. 
'berp@0:1.0:1000_10000:0.3:0' beta_fn = schedule_utils.get_schedule_fn(base=ceb_beta, sched=ceb_beta_schedule, step=g_step) def img_pred_summary_fn(obs, pred): utils.replay_summary('y0', g_step, reshape=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval)( obs, None) utils.replay_summary('y0_pred', g_step, reshape=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval)( pred, None) utils.replay_summary('y0_pred_diff', g_step, reshape=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval)( ((obs - pred) / 2.0 + 0.5), None) ceb = ceb_task.CEB(beta_fn=beta_fn, generative_ratio=ceb_generative_ratio, generative_items=ceb_generative_items, step_counter=g_step, img_pred_summary_fn=img_pred_summary_fn) m_ceb = ceb_task.CEBTask( ceb, e_enc, b_enc, forward_head=e_head, backward_head=b_head, y_decoders=y_decoders, learn_backward_enc=(enc_ema_tau is None), action_condition=ceb_action_condition, backward_encode_rewards=ceb_backward_encode_rewards, optimizer=m_optimizer, grad_clip=feature_grad_clip, global_step=g_step) if train_next_frame_decoder: ns_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3) next_frame_deconv = utils.SimpleDeconv(conv_feature_spec, output_tensor_spec=ims_spec) next_frame_decoder = utils.PixelDecoder( next_frame_deconv, optimizer=ns_optimizer, step_counter=g_step, image_summary_interval=image_summary_interval, frame_stack=frame_stack) next_frame_deconv.create_variables() # Agent training actor_lr_fn = schedule_utils.get_schedule_fn(base=actor_learning_rate, sched=actor_lr_schedule, step=g_step) critic_lr_fn = schedule_utils.get_schedule_fn(base=critic_learning_rate, sched=critic_lr_schedule, step=g_step) alpha_lr_fn = schedule_utils.get_schedule_fn(base=alpha_learning_rate, sched=alpha_lr_schedule, step=g_step) actor_net = actor_distribution_network.ActorDistributionNetwork( conv_feature_spec, action_spec, kernel_initializer=actor_kernel_init, 
fc_layer_params=actor_fc_layers, activation_fn=tf.keras.activations.relu, continuous_projection_net=normal_proj_net) critic_net = critic_network.CriticNetwork( (conv_feature_spec, action_spec), observation_fc_layer_params=critic_obs_fc_layers, action_fc_layer_params=critic_action_fc_layers, joint_fc_layer_params=critic_joint_fc_layers, activation_fn=tf.nn.relu, kernel_initializer=critic_kernel_init, last_kernel_initializer=critic_last_kernel_init) tf_agent = sac_agent.SacAgent( ts.time_step_spec(observation_spec=conv_feature_spec), action_spec, actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=actor_lr_fn), critic_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=critic_lr_fn), alpha_optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=alpha_lr_fn), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=gradient_clipping, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=g_step) tf_agent.initialize() env_steps = tf_metrics.EnvironmentSteps(prefix='Train') average_return = tf_metrics.AverageReturnMetric( prefix='Train', buffer_size=num_eval_episodes, batch_size=tf_env.batch_size) train_metrics = [ tf_metrics.NumberOfEpisodes(prefix='Train'), env_steps, average_return, tf_metrics.AverageEpisodeLengthMetric(prefix='Train', buffer_size=num_eval_episodes, batch_size=tf_env.batch_size), tf_metrics.AverageReturnMetric(name='LatestReturn', prefix='Train', buffer_size=1, batch_size=tf_env.batch_size) ] # Collect and eval policies initial_collect_policy = random_tf_policy.RandomTFPolicy( tf_env.time_step_spec(), action_spec) eval_policy = tf_agent.policy if greedy_eval_policy: eval_policy = greedy_policy.GreedyPolicy(eval_policy) def obs_to_feature(observation): feature, _ = e_enc(observation['pixels'], 
training=False) return tf.stop_gradient(feature) eval_policy = FeaturePolicy(policy=eval_policy, time_step_spec=tf_env.time_step_spec(), obs_to_feature_fn=obs_to_feature) collect_policy = FeaturePolicy(policy=tf_agent.collect_policy, time_step_spec=tf_env.time_step_spec(), obs_to_feature_fn=obs_to_feature) # Make the replay buffer. replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=collect_policy.trajectory_spec, batch_size=1, max_length=replay_buffer_capacity) replay_observer = [replay_buffer.add_batch] # Checkpoints train_checkpointer = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'train'), agent=tf_agent, actor_net=actor_net, critic_net=critic_net, global_step=g_step, metrics=tfa_metric_utils.MetricsGroup(train_metrics, 'train_metrics')) train_checkpointer.initialize_or_restore() policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'policy'), policy=eval_policy, global_step=g_step) policy_checkpointer.initialize_or_restore() rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'replay_buffer'), max_to_keep=1, replay_buffer=replay_buffer, global_step=g_step) rb_checkpointer.initialize_or_restore() if learn_ceb: d = dict() if future_deconv is not None: d.update(future_deconv=future_deconv) if future_reward_mlp is not None: d.update(future_reward_mlp=future_reward_mlp) model_ckpt = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'model'), forward_encoder=e_enc, forward_encoder_target=e_enc_t, forward_head=e_head, backward_encoder=b_enc, backward_head=b_head, global_step=g_step, **d) else: model_ckpt = common.Checkpointer(ckpt_dir=os.path.join( root_dir, 'model'), forward_encoder=e_enc, forward_encoder_target=e_enc_t, global_step=g_step) model_ckpt.initialize_or_restore() if train_next_frame_decoder: next_frame_decoder_ckpt = common.Checkpointer( ckpt_dir=os.path.join(root_dir, 'next_frame_decoder'), next_frame_decoder=next_frame_decoder, next_frame_deconv=next_frame_deconv, 
global_step=g_step) next_frame_decoder_ckpt.initialize_or_restore() if use_tf_functions and not drivers_in_graph: collect_policy.action = common.function(collect_policy.action) initial_collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, initial_collect_policy, observers=replay_observer + train_metrics, num_steps=initial_collect_steps) collect_driver = dynamic_step_driver.DynamicStepDriver( tf_env, collect_policy, observers=replay_observer + train_metrics, num_steps=collect_steps_per_iteration) if use_tf_functions and drivers_in_graph: initial_collect_driver.run = common.function( initial_collect_driver.run) collect_driver.run = common.function(collect_driver.run) # Collect initial replay data. if env_steps.result() == 0 or replay_buffer.num_frames() == 0: qj(initial_collect_steps, 'Initializing replay buffer by collecting random experience', tic=1) initial_collect_driver.run() for train_metric in train_metrics: train_metric.tf_summaries(train_step=env_steps.result()) qj(s='Done initializing replay buffer', toc=1) time_step = None policy_state = collect_policy.get_initial_state(tf_env.batch_size) time_acc = 0 env_steps_before = env_steps.result().numpy() paddings = tf.constant([[4, 4], [4, 4], [0, 0]]) def random_shifting(traj, meta): x0 = traj.observation['pixels'][0] x1 = traj.observation['pixels'][1] y0 = traj.observation['pixels'][frame_stack] y1 = traj.observation['pixels'][frame_stack + 1] x0 = tf.pad(x0, paddings, 'SYMMETRIC') x1 = tf.pad(x1, paddings, 'SYMMETRIC') y0 = tf.pad(y0, paddings, 'SYMMETRIC') y1 = tf.pad(y1, paddings, 'SYMMETRIC') x0a = tf.image.random_crop(x0, ims_shape) x1a = tf.image.random_crop(x1, ims_shape) x0 = tf.image.random_crop(x0, ims_shape) x1 = tf.image.random_crop(x1, ims_shape) y0 = tf.image.random_crop(y0, ims_shape) y1 = tf.image.random_crop(y1, ims_shape) return (traj, (x0, x1, x0a, x1a, y0, y1)), meta # Dataset generates trajectories with shape [B, T, ...] 
num_steps = frame_stack + 2 with tf.device('/cpu:0'): if image_aug_type == 'random_shifting': dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=num_steps).unbatch().filter( utils.filter_invalid_transition).map( random_shifting, num_parallel_calls=3).batch(batch_size).map( utils.replay_summary( 'replay/filtered', order_frame_stack=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval, has_augmentations=True)) elif image_aug_type is None: dataset = replay_buffer.as_dataset( sample_batch_size=batch_size, num_steps=num_steps).unbatch().filter( utils.filter_invalid_transition).batch(batch_size).map( utils.replay_summary( 'replay/filtered', order_frame_stack=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval, has_augmentations=False)) else: raise NotImplementedError iterator_nstep = iter(dataset) def model_train_step(experience): if image_aug_type == 'random_shifting': experience, cropped_frames = experience x0, x1, _, _, y0, y1 = cropped_frames r0, r1, a0, a1 = utils.split_xy(experience, frame_stack, rewards_n_actions_only=True) x0 = x0[:, None, ...] x1 = x1[:, None, ...] y0 = y0[:, None, ...] y1 = y1[:, None, ...] 
elif image_aug_type is None: x0, x1, y0, y1, r0, r1, a0, a1 = utils.split_xy( experience, frame_stack, rewards_n_actions_only=False) else: raise NotImplementedError # Flatten stacked actions action_shape = a0.shape.as_list() a0 = tf.reshape(a0, [action_shape[0], action_shape[1], -1]) a1 = tf.reshape(a1, [action_shape[0], action_shape[1], -1]) if image_summary_interval > 0: utils.replay_summary( 'ceb/x0', g_step, reshape=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval)(x0, None) utils.replay_summary( 'ceb/x1', g_step, reshape=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval)(x1, None) utils.replay_summary( 'ceb/y0', g_step, reshape=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval)(y0, None) utils.replay_summary( 'ceb/y1', g_step, reshape=True, frame_stack=frame_stack, image_summary_interval=image_summary_interval)(y1, None) ceb_loss, feat_x0, zx0 = m_ceb.train(x0, a0, y0, y1, r0, r1, m_vars) if train_next_frame_decoder: # zx0: [B, 1, Z] zx0 = tf.squeeze(zx0, axis=1) # y0: [B, 1, H, W, Cxframe_stack] next_obs = tf.cast(tf.squeeze(y0, axis=1), tf.float32) / 255.0 next_frame_decoder.train(next_obs, tf.stop_gradient(zx0)) if enc_ema_tau is not None: common.soft_variables_update(e_enc.variables, b_enc.variables, tau=enc_ema_tau, tau_non_trainable=enc_ema_tau) def agent_train_step(experience): # preprocess experience if image_aug_type == 'random_shifting': experience, cropped_frames = experience x0, x1, x0a, x1a, y0, y1 = cropped_frames experience = tf.nest.map_structure( lambda t: composite.slice_to(t, axis=1, end=2), experience) time_steps, actions, next_time_steps = ( tf_agent.experience_to_transitions(experience)) # pylint: disable=protected-access elif image_aug_type is None: experience = tf.nest.map_structure( lambda t: composite.slice_to(t, axis=1, end=2), experience) time_steps, actions, next_time_steps = ( tf_agent.experience_to_transitions(experience)) # pylint: 
disable=protected-access x0 = time_steps.observation['pixels'] x1 = next_time_steps.observation['pixels'] else: raise NotImplementedError tf_agent.train_pix(time_steps, actions, next_time_steps, x0, x1, x0a=x0a if use_augmented_q else None, x1a=x1a if use_augmented_q else None, e_enc=e_enc, e_enc_t=e_enc_t, q_aug=use_augmented_q, use_critic_grad=use_critic_grad) def checkpoint(step): rb_checkpointer.save(global_step=step) train_checkpointer.save(global_step=step) policy_checkpointer.save(global_step=step) model_ckpt.save(global_step=step) if train_next_frame_decoder: next_frame_decoder_ckpt.save(global_step=step) def evaluate(): # Override outer record_if that may be out of sync with respect to the # env_steps.result() value used for the summay step. with tf.compat.v2.summary.record_if(True): qj(g_step.numpy(), 'Starting eval at step', tic=1) results = pisac_metric_utils.eager_compute( eval_metrics, eval_tf_env, eval_policy, histograms=eval_histograms, num_episodes=num_eval_episodes, train_step=env_steps.result(), summary_writer=summary_writer, summary_prefix='Eval', use_function=drivers_in_graph, ) if eval_metrics_callback is not None: eval_metrics_callback(results, env_steps.result()) tfa_metric_utils.log_metrics(eval_metrics) qj(s='Finished eval', toc=1) def update_target(): common.soft_variables_update( e_enc.variables, e_enc_t.variables, tau=tf_agent.target_update_tau, tau_non_trainable=tf_agent.target_update_tau) common.soft_variables_update( tf_agent._critic_network_1.variables, # pylint: disable=protected-access tf_agent._target_critic_network_1.variables, # pylint: disable=protected-access tau=tf_agent.target_update_tau, tau_non_trainable=tf_agent.target_update_tau) common.soft_variables_update( tf_agent._critic_network_2.variables, # pylint: disable=protected-access tf_agent._target_critic_network_2.variables, # pylint: disable=protected-access tau=tf_agent.target_update_tau, tau_non_trainable=tf_agent.target_update_tau) if use_tf_functions: if learn_ceb: 
m_ceb.train = common.function(m_ceb.train) model_train_step = common.function(model_train_step) agent_train_step = common.function(agent_train_step) tf_agent.train_pix = common.function(tf_agent.train_pix) update_target = common.function(update_target) if train_next_frame_decoder: next_frame_decoder.train = common.function( next_frame_decoder.train) if not learn_ceb and initial_feature_step > 0: raise ValueError('Not learning CEB but initial_feature_step > 0') with tf.summary.record_if( lambda: tf.math.equal(g_step % summary_interval, 0)): if learn_ceb and g_step.numpy() < initial_feature_step: qj(initial_feature_step, 'Pretraining CEB...', tic=1) for _ in range(g_step.numpy(), initial_feature_step): with tf.name_scope('LearningRates'): tf.summary.scalar(name='CEB learning rate', data=feature_lr_fn(), step=g_step) experience, _ = next(iterator_nstep) model_train_step(experience) g_step.assign_add(1) qj(s='Done pretraining CEB.', toc=1) first_step = True for _ in range(g_step.numpy(), num_iterations): g_step_val = g_step.numpy() start_time = time.time() with tf.summary.record_if( lambda: tf.math.equal(g_step % summary_interval, 0)): with tf.name_scope('LearningRates'): tf.summary.scalar(name='Actor learning rate', data=actor_lr_fn(), step=g_step) tf.summary.scalar(name='Critic learning rate', data=critic_lr_fn(), step=g_step) tf.summary.scalar(name='Alpha learning rate', data=alpha_lr_fn(), step=g_step) if learn_ceb: tf.summary.scalar(name='CEB learning rate', data=feature_lr_fn(), step=g_step) with tf.name_scope('Train'): tf.summary.scalar(name='StepsVsEnvironmentSteps', data=env_steps.result(), step=g_step) tf.summary.scalar(name='StepsVsAverageReturn', data=average_return.result(), step=g_step) if g_step_val % collect_every == 0: time_step, policy_state = collect_driver.run( time_step=time_step, policy_state=policy_state, ) experience, _ = next(iterator_nstep) agent_train_step(experience) if (g_step_val - initial_feature_step) % tf_agent.target_update_period == 
0: update_target() if learn_ceb: model_train_step(experience) time_acc += time.time() - start_time # Increment global step counter. g_step.assign_add(1) g_step_val = g_step.numpy() if (g_step_val - initial_feature_step) % log_interval == 0: for train_metric in train_metrics: train_metric.tf_summaries(train_step=env_steps.result()) logging.info('env steps = %d, average return = %f', env_steps.result(), average_return.result()) env_steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc logging.info('%.3f env steps/sec', env_steps_per_sec) tf.compat.v2.summary.scalar(name='env_steps_per_sec', data=env_steps_per_sec, step=env_steps.result()) time_acc = 0 env_steps_before = env_steps.result().numpy() if (g_step_val - initial_feature_step) % eval_interval == 0: eval_start_time = time.time() evaluate() logging.info('eval time %.3f sec', time.time() - eval_start_time) if (g_step_val - initial_feature_step) % checkpoint_interval == 0: checkpoint(g_step_val) # Write gin config to Tensorboard if first_step: summ = utils.Summ(0, root_dir) conf = gin.operative_config_str() conf = ' ' + conf.replace('\n', '\n ') summ.text('gin/config', conf) summ.flush() first_step = False # Final checkpoint. checkpoint(g_step.numpy()) # Final evaluation. evaluate()
def model_train_step(experience):
    """Run one training step of the CEB model on a sampled batch.

    Splits `experience` into frames, rewards and actions according to the
    enclosing `image_aug_type`, optionally emits image summaries, trains
    `m_ceb`, optionally trains the next-frame decoder, and optionally
    applies `common.soft_variables_update` from `e_enc` to `b_enc`.

    Args:
        experience: A sampled batch. When `image_aug_type` is
            'random_shifting' it is unpacked as a
            `(experience, cropped_frames)` pair; when `image_aug_type` is
            None it is passed straight to `utils.split_xy`.

    Raises:
        NotImplementedError: If `image_aug_type` is neither
            'random_shifting' nor None.
    """
    if image_aug_type is None:
        x0, x1, y0, y1, r0, r1, a0, a1 = utils.split_xy(
            experience, frame_stack, rewards_n_actions_only=False)
    elif image_aug_type == 'random_shifting':
        transitions, crops = experience
        x0, x1, _, _, y0, y1 = crops
        r0, r1, a0, a1 = utils.split_xy(
            transitions, frame_stack, rewards_n_actions_only=True)
        # Insert a singleton axis at position 1 of every cropped frame.
        x0, x1, y0, y1 = (
            frames[:, None, ...] for frames in (x0, x1, y0, y1))
    else:
        raise NotImplementedError

    # Flatten stacked actions: keep the two leading dims, merge the rest.
    leading_dims = a0.shape.as_list()
    flat_action_shape = [leading_dims[0], leading_dims[1], -1]
    a0 = tf.reshape(a0, flat_action_shape)
    a1 = tf.reshape(a1, flat_action_shape)

    if image_summary_interval > 0:
        tagged_frames = (
            ('ceb/x0', x0),
            ('ceb/x1', x1),
            ('ceb/y0', y0),
            ('ceb/y1', y1),
        )
        for tag, frames in tagged_frames:
            write_summary = utils.replay_summary(
                tag,
                g_step,
                reshape=True,
                frame_stack=frame_stack,
                image_summary_interval=image_summary_interval)
            write_summary(frames, None)

    # Only the third output of the CEB train call is used below.
    _, _, zx0 = m_ceb.train(x0, a0, y0, y1, r0, r1, m_vars)

    if train_next_frame_decoder:
        # zx0: [B, 1, Z]
        latent = tf.squeeze(zx0, axis=1)
        # y0: [B, 1, H, W, Cxframe_stack]
        next_obs = tf.cast(tf.squeeze(y0, axis=1), tf.float32) / 255.0
        # The decoder is trained on a gradient-stopped latent.
        next_frame_decoder.train(next_obs, tf.stop_gradient(latent))

    if enc_ema_tau is not None:
        # The same tau is passed for trainable and non-trainable variables.
        common.soft_variables_update(
            e_enc.variables,
            b_enc.variables,
            tau=enc_ema_tau,
            tau_non_trainable=enc_ema_tau)
def __init__(self,
             action_spec,
             actor_network: Network,
             critic_network: Network,
             ou_stddev=0.2,
             ou_damping=0.15,
             critic_loss=None,
             target_update_tau=0.05,
             target_update_period=1,
             dqda_clipping=None,
             actor_optimizer=None,
             critic_optimizer=None,
             gradient_clipping=None,
             train_step_counter=None,
             debug_summaries=False,
             name="DdpgAlgorithm"):
    """Build the DDPG algorithm from an actor and a critic network.

    Target copies of both networks are created, a target updater is
    configured, and the target variables are synced from the online
    networks with ``tau=1.0``.

    Args:
        action_spec (nested BoundedTensorSpec): representing the actions.
        actor_network (Network): The network will be called with
            call(observation, step_type).
        critic_network (Network): The network will be called with
            call(observation, action, step_type).
        ou_stddev (float): Standard deviation for the Ornstein-Uhlenbeck
            (OU) noise added in the default collect policy.
        ou_damping (float): Damping factor for the OU noise added in the
            default collect policy.
        critic_loss (None|OneStepTDLoss): an object for calculating the
            critic loss. If None, a default OneStepTDLoss will be used.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        dqda_clipping (float): when computing the actor loss, clips the
            gradient dqda element-wise between
            [-dqda_clipping, dqda_clipping]. Does not perform clipping if
            dqda_clipping == 0.
        actor_optimizer (tf.optimizers.Optimizer): The optimizer for the
            actor.
        critic_optimizer (tf.optimizers.Optimizer): The optimizer for the
            critic.
        gradient_clipping (float): Norm length to clip gradients.
        train_step_counter (tf.Variable): An optional counter to increment
            every time a new iteration is started. If None, it will use
            tf.summary.experimental.get_step(). If this is still None, a
            counter will be created.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    actor_state_spec = DdpgActorState(
        actor=actor_network.state_spec,
        critic=critic_network.state_spec)
    critic_state_spec = DdpgCriticState(
        critic=critic_network.state_spec,
        target_actor=actor_network.state_spec,
        target_critic=critic_network.state_spec)
    train_state_spec = DdpgState(
        actor=actor_state_spec, critic=critic_state_spec)

    def _actor_trainable_variables():
        return actor_network.trainable_variables

    def _critic_trainable_variables():
        return critic_network.trainable_variables

    # Optimizers and variable getters are passed pairwise: actor, critic.
    super().__init__(
        action_spec,
        train_state_spec=train_state_spec,
        action_distribution_spec=action_spec,
        optimizer=[actor_optimizer, critic_optimizer],
        get_trainable_variables_func=[
            _actor_trainable_variables,
            _critic_trainable_variables,
        ],
        gradient_clipping=gradient_clipping,
        train_step_counter=train_step_counter,
        debug_summaries=debug_summaries,
        name=name)

    self._actor_network = actor_network
    self._critic_network = critic_network
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer

    self._target_actor_network = actor_network.copy(
        name='target_actor_network')
    self._target_critic_network = critic_network.copy(
        name='target_critic_network')

    self._ou_stddev = ou_stddev
    self._ou_damping = ou_damping

    self._critic_loss = (
        OneStepTDLoss(debug_summaries=debug_summaries)
        if critic_loss is None else critic_loss)

    self._ou_process = self._create_ou_process(ou_stddev, ou_damping)

    online_models = [self._actor_network, self._critic_network]
    target_models = [
        self._target_actor_network, self._target_critic_network
    ]
    self._update_target = common.get_target_updater(
        models=online_models,
        target_models=target_models,
        tau=target_update_tau,
        period=target_update_period)

    self._dqda_clipping = dqda_clipping

    # Initial sync with tau=1.0: critic first, then actor.
    sync_pairs = (
        (self._critic_network, self._target_critic_network),
        (self._actor_network, self._target_actor_network),
    )
    for online_net, target_net in sync_pairs:
        tfa_common.soft_variables_update(
            online_net.variables, target_net.variables, tau=1.0)
def update():
    """Update the target Q-network variables from the Q-network.

    Returns:
        The result of `common_utils.soft_variables_update` called with the
        Q-network variables, the target Q-network variables and the
        enclosing scope's `tau`.
    """
    online_variables = self._q_network.variables
    target_variables = self._target_q_network.variables
    return common_utils.soft_variables_update(
        online_variables, target_variables, tau)
def update():
    """Update the target Q-network variables from the Q-network.

    `tau` comes from the enclosing scope; `tau_non_trainable` is fixed
    at 1.0.

    Returns:
        The result of `common.soft_variables_update`.
    """
    online_variables = self._q_network.variables
    target_variables = self._target_q_network.variables
    update_op = common.soft_variables_update(
        online_variables,
        target_variables,
        tau,
        tau_non_trainable=1.0)
    return update_op
def update():
    """Update the `TargetQMIXNet` variables from `QMIXNet`.

    `tau` comes from the enclosing scope; `tau_non_trainable` is fixed
    at 1.0.

    Returns:
        The result of `common.soft_variables_update`.
    """
    mixer_variables = self.QMIXNet.variables
    target_mixer_variables = self.TargetQMIXNet.variables
    update_op = common.soft_variables_update(
        mixer_variables,
        target_mixer_variables,
        tau,
        tau_non_trainable=1.0)
    return update_op
def _initialize(self):
    """Sync the target Q-network with the Q-network.

    Calls `common.soft_variables_update` with `tau=1.0`, with the
    Q-network variables as source and the target Q-network variables as
    target. The result of that call is not returned.
    """
    online_variables = self._q_network.variables
    target_variables = self._target_q_network.variables
    common.soft_variables_update(
        online_variables, target_variables, tau=1.0)
def update():
    """Update the target network variables from the value network.

    `tau` comes from the enclosing scope; `tau_non_trainable` is fixed
    at 1.0.

    Returns:
        The result of `tfagents_common.soft_variables_update`.
    """
    value_variables = self._value_network.variables
    target_variables = self._target_network.variables
    update_op = tfagents_common.soft_variables_update(
        value_variables,
        target_variables,
        tau,
        tau_non_trainable=1.0)
    return update_op
def _initialize(self):
    """Sync the target network with the value network.

    Calls `tfagents_common.soft_variables_update` with `tau=1.0`, with
    the value-network variables as source and the target-network
    variables as target. The result of that call is not returned.
    """
    value_variables = self._value_network.variables
    target_variables = self._target_network.variables
    tfagents_common.soft_variables_update(
        value_variables, target_variables, tau=1.0)
def update():
    """Update the target safety-critic network.

    Returns:
        The result of `common.soft_variables_update` called with the
        safety-critic variables, the target safety-critic variables and
        the enclosing scope's `tau`.
    """
    critic_variables = self._safety_critic_network.variables
    target_critic_variables = self._target_safety_critic_network.variables
    return common.soft_variables_update(
        critic_variables, target_critic_variables, tau)