Exemplo n.º 1
0
            def update():
                """Build the combined soft-update op for both target critics.

                Critic 1 is blended into target critic 1 first. The variable
                lists used for critic 2 and target critic 2 come from
                ``common.deduped_network_variables`` called against the
                corresponding critic-1 network (presumably so that variables
                shared with critic 1 are not updated twice -- confirm against
                the helper). Both soft updates use ``tau`` for trainable
                variables and ``tau_non_trainable=1.0``.

                Returns:
                    The result of ``tf.group`` over the two update ops.
                """
                first_critic_op = common.soft_variables_update(
                    self._critic_network_1.variables,
                    self._target_critic_network_1.variables,
                    tau,
                    tau_non_trainable=1.0)

                second_source_vars = common.deduped_network_variables(
                    self._critic_network_2, self._critic_network_1)
                second_target_vars = common.deduped_network_variables(
                    self._target_critic_network_2,
                    self._target_critic_network_1)

                second_critic_op = common.soft_variables_update(
                    second_source_vars,
                    second_target_vars,
                    tau,
                    tau_non_trainable=1.0)

                return tf.group(first_critic_op, second_critic_op)
Exemplo n.º 2
0
    def testUpdateOnlyTargetVariables(self, tau):
        """Check that a soft update rewrites target variables only.

        Two fully connected layers are built under the 'source' and 'target'
        scopes. After running the update op, every source variable must be
        unchanged and every target variable must equal
        ``tau * source + (1 - tau) * target`` computed from the old values.
        """
        inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
        tf.contrib.layers.fully_connected(inputs, 2, scope='source')
        tf.contrib.layers.fully_connected(inputs, 2, scope='target')

        src_vars = tf.contrib.framework.get_model_variables('source')
        tgt_vars = tf.contrib.framework.get_model_variables('target')
        soft_update = common.soft_variables_update(src_vars, tgt_vars, tau)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        old_src_vals, old_tgt_vals = self.evaluate([src_vars, tgt_vars])
        self.evaluate(soft_update)
        new_src_vals, new_tgt_vals = self.evaluate([src_vars, tgt_vars])

        snapshots = zip(old_src_vals, old_tgt_vals, new_src_vals, new_tgt_vals)
        for old_src, old_tgt, new_src, new_tgt in snapshots:
            # The source side must be untouched by the update.
            self.assertAllClose(new_src, old_src)
            # The target side must be the tau-weighted blend of the old values.
            blended = tau * old_src + (1 - tau) * old_tgt
            self.assertAllClose(new_tgt, blended)
Exemplo n.º 3
0
    def _initialize(self):
        """Hard-copy every critic network into its target network.

        Each critic's variables are copied into the matching target critic
        with ``tau=1.0``. The two "no entropy" critics are copied as well when
        ``self._critic_network_no_entropy_1`` is not None.

        The results of the ``soft_variables_update`` calls are discarded and
        this method returns nothing.
        """

        def _hard_copy(source_net, target_net):
            # tau=1.0 makes the soft update a plain copy.
            common.soft_variables_update(source_net.variables,
                                         target_net.variables,
                                         tau=1.0)

        _hard_copy(self._critic_network_1, self._target_critic_network_1)
        _hard_copy(self._critic_network_2, self._target_critic_network_2)

        if self._critic_network_no_entropy_1 is None:
            return

        _hard_copy(self._critic_network_no_entropy_1,
                   self._target_critic_network_no_entropy_1)
        _hard_copy(self._critic_network_no_entropy_2,
                   self._target_critic_network_no_entropy_2)
Exemplo n.º 4
0
 def _initialize(self):
     """Hard-copy the Q network into its target network(s).

     The Q network's variables are always copied into the target Q network
     with ``tau=1.0``. When ``self._enable_td3`` is truthy they are also
     copied into the two delayed target networks. Nothing is returned.
     """
     target_nets = [self._target_q_network]
     if self._enable_td3:
         target_nets.append(self._target_q_network_delayed)
         target_nets.append(self._target_q_network_delayed_2)

     for target_net in target_nets:
         common.soft_variables_update(self._q_network.variables,
                                      target_net.variables,
                                      tau=1.0)
Exemplo n.º 5
0
  def testUpdateOnlyTargetVariables(self, tau):
    """Check, in an explicit graph and session, that only targets change.

    Source variables must keep their values after the update op runs, and
    each target variable must equal ``tau * source + (1 - tau) * target``
    computed from the values read before the update.
    """
    with tf.Graph().as_default() as graph:
      inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
      tf.contrib.layers.fully_connected(inputs, 2, scope='source')
      tf.contrib.layers.fully_connected(inputs, 2, scope='target')

      src_vars = tf.contrib.framework.get_model_variables('source')
      tgt_vars = tf.contrib.framework.get_model_variables('target')
      soft_update = common.soft_variables_update(src_vars, tgt_vars, tau)
      with self.test_session(graph=graph) as session:
        tf.global_variables_initializer().run()
        old_src_vals, old_tgt_vals = session.run([src_vars, tgt_vars])
        session.run(soft_update)
        new_src_vals, new_tgt_vals = session.run([src_vars, tgt_vars])
        snapshots = zip(old_src_vals, old_tgt_vals, new_src_vals, new_tgt_vals)
        for old_src, old_tgt, new_src, new_tgt in snapshots:
          # The source side must be untouched by the update.
          self.assertAllClose(new_src, old_src)
          # The target side must be the tau-weighted blend of the old values.
          blended = tau*old_src + (1-tau)*old_tgt
          self.assertAllClose(new_tgt, blended)
Exemplo n.º 6
0
    def _initialize(self):
        """Hard-copy the critic and actor networks into their targets.

        With ``tau=1.0`` the variables of critic 1, critic 2 and the actor
        network are copied, in that order, into the corresponding target
        networks. Nothing is returned.
        """
        source_target_pairs = (
            (self._critic_network_1, self._target_critic_network_1),
            (self._critic_network_2, self._target_critic_network_2),
            (self._actor_network, self._target_actor_network),
        )
        for source_net, target_net in source_target_pairs:
            common.soft_variables_update(source_net.variables,
                                         target_net.variables,
                                         tau=1.0)
Exemplo n.º 7
0
 def update_target():
     """Soft-update the target encoder and both target critics.

     The pairs are updated in order: ``e_enc`` into ``e_enc_t``, then the
     agent's critic 1 and critic 2 into their target critics. The agent's
     ``target_update_tau`` is used both as ``tau`` and as
     ``tau_non_trainable``. Nothing is returned.
     """
     source_target_pairs = (
         (e_enc, e_enc_t),
         (tf_agent._critic_network_1,  # pylint: disable=protected-access
          tf_agent._target_critic_network_1),  # pylint: disable=protected-access
         (tf_agent._critic_network_2,  # pylint: disable=protected-access
          tf_agent._target_critic_network_2),  # pylint: disable=protected-access
     )
     for source_net, target_net in source_target_pairs:
         common.soft_variables_update(
             source_net.variables,
             target_net.variables,
             tau=tf_agent.target_update_tau,
             tau_non_trainable=tf_agent.target_update_tau)
Exemplo n.º 8
0
    def update(self, policy, tau=1.0, sort_variables_by_name=False):
        """Update this policy's variables from another policy.

        Args:
          policy: The policy whose variables are used as the update source.
          tau: A float scalar in [0, 1]. With the default of 1.0 the update
            is a hard update.
          sort_variables_by_name: A bool; when True the variables are sorted
            by name before the update is performed.

        Returns:
          A TF op that performs the update, or ``tf.no_op()`` when this
          policy has no variables.
        """
        if not self.variables():
            # Nothing to update on this side.
            return tf.no_op()

        return common.soft_variables_update(
            policy.variables(),
            self.variables(),
            tau=tau,
            sort_variables_by_name=sort_variables_by_name)
Exemplo n.º 9
0
    def update_partial(self, policy, tau=1.0):
        """Update a leading subset of this policy's variables from a policy.

        Only the first ``len(policy.variables())`` entries of this policy's
        variable list are used as update targets. The update is requested
        with ``tau_non_trainable=None`` and ``sort_variables_by_name=True``.

        Args:
          policy: The policy whose variables are used as the update source.
          tau: A float scalar in [0, 1]. With the default of 1.0 the update
            is a hard update. This is used for trainable variables.

        Returns:
          A TF op that performs the update, or ``tf.no_op()`` when this
          policy has no variables.
        """
        if not self.variables():
            # Nothing to update on this side.
            return tf.no_op()

        source_vars = policy.variables()
        matching_targets = self.variables()[:len(source_vars)]
        return common.soft_variables_update(
            source_vars,
            matching_targets,
            tau=tau,
            tau_non_trainable=None,
            sort_variables_by_name=True)
Exemplo n.º 10
0
    def testUpdateOnlyTargetVariables(self, tau):
        """Check with Keras Dense layers that only target weights change.

        The trainable weights of two Dense layers serve as source and target.
        After the soft update, source weights must be unchanged and target
        weights must equal ``tau * source + (1 - tau) * target`` computed
        from the values read before the update.
        """
        inputs = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
        source_layer = tf.keras.layers.Dense(2, name='source_net')
        target_layer = tf.keras.layers.Dense(2, name='target_net')

        # Calling each layer once forces its variables to be created.
        source_layer(inputs)
        target_layer(inputs)

        src_weights = source_layer.trainable_weights
        tgt_weights = target_layer.trainable_weights

        self.evaluate(tf.compat.v1.global_variables_initializer())
        old_src_vals, old_tgt_vals = self.evaluate([src_weights, tgt_weights])

        soft_update = common.soft_variables_update(src_weights, tgt_weights,
                                                   tau)
        self.evaluate(soft_update)
        new_src_vals, new_tgt_vals = self.evaluate([src_weights, tgt_weights])

        snapshots = zip(old_src_vals, old_tgt_vals, new_src_vals, new_tgt_vals)
        for old_src, old_tgt, new_src, new_tgt in snapshots:
            # The source side must be untouched by the update.
            self.assertAllClose(new_src, old_src)
            # The target side must be the tau-weighted blend of the old values.
            blended = tau * old_src + (1 - tau) * old_tgt
            self.assertAllClose(new_tgt, blended)
Exemplo n.º 11
0
    def __init__(self,
                 action_spec,
                 actor_network: Network,
                 critic_network: Network,
                 critic_loss=None,
                 target_entropy=None,
                 initial_log_alpha=0.0,
                 target_update_tau=0.05,
                 target_update_period=1,
                 dqda_clipping=None,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 alpha_optimizer=None,
                 gradient_clipping=None,
                 train_step_counter=None,
                 debug_summaries=False,
                 name="SacAlgorithm"):
        """Create a SacAlgorithm.

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            actor_network (Network): The network will be called with
                call(observation, step_type).
            critic_network (Network): The network will be called with
                call(observation, action, step_type). It is used as the first
                critic; a copy of it (named 'CriticNetwork2') is used as the
                second critic.
            critic_loss (None|OneStepTDLoss): an object for calculating critic
                loss. If None, a default OneStepTDLoss will be used.
            target_entropy (float|None): The target average policy entropy,
                for updating alpha. If None, it is computed by summing
                dist_utils.calc_default_target_entropy over the flattened
                action spec.
            initial_log_alpha (float): initial value for variable log_alpha.
            target_update_tau (float): Factor for soft update of the target
                networks.
            target_update_period (int): Period for soft update of the target
                networks.
            dqda_clipping (float): when computing the actor loss, clips the
                gradient dqda element-wise between
                [-dqda_clipping, dqda_clipping]. Does not perform clipping if
                dqda_clipping == 0. NOTE(review): the default here is None;
                how None is handled is not visible in this constructor --
                confirm where the value is consumed.
            actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
            critic_optimizer (tf.optimizers.Optimizer): The optimizer for
                critic.
            alpha_optimizer (tf.optimizers.Optimizer): The optimizer for alpha.
            gradient_clipping (float): Norm length to clip gradients.
            train_step_counter (tf.Variable): An optional counter to increment
                every time a new iteration is started. If None, it will use
                tf.summary.experimental.get_step(). If this is still None, a
                counter will be created.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this algorithm.
        """
        # The second critic is a copy of the network passed in as the first.
        critic_network1 = critic_network
        critic_network2 = critic_network.copy(name='CriticNetwork2')
        log_alpha = tfa_common.create_variable(name='log_alpha',
                                               initial_value=initial_log_alpha,
                                               dtype=tf.float32,
                                               trainable=True)
        # The three optimizers are paired positionally with the three
        # trainable-variable getters below: actor, both critics, log_alpha.
        super().__init__(
            action_spec,
            train_state_spec=SacState(
                share=SacShareState(actor=actor_network.state_spec),
                actor=SacActorState(critic1=critic_network.state_spec,
                                    critic2=critic_network.state_spec),
                critic=SacCriticState(
                    critic1=critic_network.state_spec,
                    critic2=critic_network.state_spec,
                    target_critic1=critic_network.state_spec,
                    target_critic2=critic_network.state_spec)),
            action_distribution_spec=actor_network.output_spec,
            predict_state_spec=actor_network.state_spec,
            optimizer=[actor_optimizer, critic_optimizer, alpha_optimizer],
            get_trainable_variables_func=[
                lambda: actor_network.trainable_variables, lambda:
                (critic_network1.trainable_variables + critic_network2.
                 trainable_variables), lambda: [log_alpha]
            ],
            gradient_clipping=gradient_clipping,
            train_step_counter=train_step_counter,
            debug_summaries=debug_summaries,
            name=name)

        self._log_alpha = log_alpha
        self._actor_network = actor_network
        self._critic_network1 = critic_network1
        self._critic_network2 = critic_network2
        # Each target critic starts out as a copy of its critic network.
        self._target_critic_network1 = self._critic_network1.copy(
            name='TargetCriticNetwork1')
        self._target_critic_network2 = self._critic_network2.copy(
            name='TargetCriticNetwork2')
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer

        if critic_loss is None:
            critic_loss = OneStepTDLoss(debug_summaries=debug_summaries)
        self._critic_loss = critic_loss

        flat_action_spec = tf.nest.flatten(self._action_spec)
        # Only the first flattened action spec decides continuous vs. not.
        self._is_continuous = tensor_spec.is_continuous(flat_action_spec[0])
        if target_entropy is None:
            # Default: sum of the per-spec default target entropies.
            target_entropy = np.sum(
                list(
                    map(dist_utils.calc_default_target_entropy,
                        flat_action_spec)))
        self._target_entropy = target_entropy

        self._dqda_clipping = dqda_clipping

        self._update_target = common.get_target_updater(
            models=[self._critic_network1, self._critic_network2],
            target_models=[
                self._target_critic_network1, self._target_critic_network2
            ],
            tau=target_update_tau,
            period=target_update_period)

        # tau=1.0: copy the critic variables into the target critics once at
        # construction time. The results of these calls are not kept.
        tfa_common.soft_variables_update(
            self._critic_network1.variables,
            self._target_critic_network1.variables,
            tau=1.0)

        tfa_common.soft_variables_update(
            self._critic_network2.variables,
            self._target_critic_network2.variables,
            tau=1.0)
Exemplo n.º 12
0
 def update_delayed():
     """Soft-update the second delayed target from the first delayed target.

     Uses ``tau_delayed`` for trainable variables and
     ``tau_non_trainable=1.0``.

     Returns:
         Whatever ``common.soft_variables_update`` returns for this update.
     """
     source_vars = self._target_q_network_delayed.variables
     target_vars = self._target_q_network_delayed_2.variables
     return common.soft_variables_update(
         source_vars, target_vars, tau_delayed, tau_non_trainable=1.0)
Exemplo n.º 13
0
 def after_train(self, training_info):
     """Blend the trained net into the prediction net, if one exists.

     When ``self._predict_net`` is falsy this does nothing. Otherwise the
     variables of ``self._net`` are soft-updated into ``self._predict_net``
     with ``tau=self._net_moving_average_rate``.

     Args:
         training_info: Not used by this method.
     """
     if not self._predict_net:
         return

     tfa_common.soft_variables_update(
         self._net.variables,
         self._predict_net.variables,
         tau=self._net_moving_average_rate)
Exemplo n.º 14
0
 def _initialize_v1(self):
   """Create the network variables and hard-copy Q into the target network.

   The Q network's variables are always created. When a target Q network
   is present, its variables are created too and the Q network's variables
   are copied into it with ``tau=1.0``.

   Returns:
     The op from ``common.soft_variables_update`` when a target Q network
     exists, otherwise ``tf.no_op()``.
   """
   self._q_network.create_variables()
   if not self._target_q_network:
     # Previously the target network was guarded for create_variables() but
     # then dereferenced unconditionally, raising AttributeError when it was
     # None. With no target there is nothing to copy into.
     return tf.no_op()
   self._target_q_network.create_variables()
   return common.soft_variables_update(
       self._q_network.variables, self._target_q_network.variables, tau=1.0)
Exemplo n.º 15
0
    def __init__(self,
                 output_dim,
                 noise_dim=32,
                 input_tensor_spec=None,
                 hidden_layers=(256, ),
                 net: Network = None,
                 net_moving_average_rate=None,
                 entropy_regularization=0.,
                 kernel_sharpness=2.,
                 mi_weight=None,
                 mi_estimator_cls=MIEstimator,
                 optimizer: tf.optimizers.Optimizer = None,
                 name="Generator"):
        """Create a Generator.

        Args:
            output_dim (int): dimension of output
            noise_dim (int): dimension of noise
            input_tensor_spec (nested TensorSpec): spec of inputs. If there is
                no inputs, this should be None.
            hidden_layers (tuple): size of hidden layers.
            net (Network): network for generating outputs from [noise, inputs]
                or noise (if inputs is None). If None, a default one with
                hidden_layers will be created
            net_moving_average_rate (float): If provided, use a moving average
                version of net to do prediction. This has been shown to be
                effective for GAN training (arXiv:1907.02544, arXiv:1812.04948).
                The check is a truthiness test, so 0 behaves like None.
            entropy_regularization (float): weight of entropy regularization.
                When it is 0, self._ml_grad is used as the gradient function;
                otherwise self._stein_grad is used.
            kernel_sharpness (float): Used only for entropy_regularization > 0.
                We calculate the kernel in SVGD as:
                    exp(-kernel_sharpness * reduce_mean((x-y)^2/width)),
                where width is the elementwise moving average of (x-y)^2
            mi_weight (float|None): If not None, a mutual information
                estimator is created with mi_estimator_cls and this value is
                stored as self._mi_weight. How the weight enters the loss is
                not visible in this constructor.
            mi_estimator_cls (type): the class of mutual information estimator
                for maximizing the mutual information between [noise, inputs]
                and [outputs, inputs].
            optimizer (tf.optimizers.Optimizer): optimizer (optional)
            name (str): name of this generator
        """
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
        self._noise_dim = noise_dim
        self._entropy_regularization = entropy_regularization
        # Pick the gradient function once, based on entropy regularization.
        if entropy_regularization == 0:
            self._grad_func = self._ml_grad
        else:
            self._grad_func = self._stein_grad
            # These two attributes only exist on the Stein-gradient path.
            self._kernel_width_averager = AdaptiveAverager(
                tensor_spec=tf.TensorSpec(shape=(output_dim, )))
            self._kernel_sharpness = kernel_sharpness

        noise_spec = tf.TensorSpec(shape=[noise_dim])

        if net is None:
            # Default network: takes noise alone, or [noise, inputs] when an
            # input spec is given.
            net_input_spec = noise_spec
            if input_tensor_spec is not None:
                net_input_spec = [net_input_spec, input_tensor_spec]
            net = EncodingNetwork(
                name="Generator",
                input_tensor_spec=net_input_spec,
                fc_layer_params=hidden_layers,
                last_layer_size=output_dim)

        self._mi_estimator = None
        self._input_tensor_spec = input_tensor_spec
        if mi_weight is not None:
            # self._mi_weight is only set when an MI estimator is created.
            x_spec = noise_spec
            y_spec = tf.TensorSpec((output_dim, ))
            if input_tensor_spec is not None:
                x_spec = [x_spec, input_tensor_spec]
            self._mi_estimator = mi_estimator_cls(
                x_spec, y_spec, sampler='shift')
            self._mi_weight = mi_weight
        self._net = net
        self._predict_net = None
        self._net_moving_average_rate = net_moving_average_rate
        if net_moving_average_rate:
            # NOTE(review): "Genrator_average" looks like a typo of
            # "Generator_average", but it is a runtime network name and is
            # left unchanged here.
            self._predict_net = net.copy(name="Genrator_average")
            # tau=1.0: start the prediction net as an exact copy of net.
            tfa_common.soft_variables_update(
                self._net.variables, self._predict_net.variables, tau=1.0)
Exemplo n.º 16
0
 def update():
   """Soft-update ``target_sc_net`` from ``sc_net`` using ``tau``.

   Returns:
     Whatever ``common.soft_variables_update`` returns for this update.
   """
   return common.soft_variables_update(
       sc_net.variables, target_sc_net.variables, tau)
Exemplo n.º 17
0
def train_eval(
        root_dir,
        random_seed=None,
        # Dataset params
        domain_name='cartpole',
        task_name='swingup',
        frame_shape=(84, 84, 3),
        image_aug_type='random_shifting',  # None/'random_shifting'
        frame_stack=3,
        action_repeat=4,
        # Params for learning
        num_env_steps=1000000,
        learn_ceb=True,
        use_augmented_q=False,
        # Params for CEB
        e_ctor=encoders.FRNConv,
        e_head_ctor=encoders.MVNormalDiagParamHead,
        b_ctor=encoders.FRNConv,
        b_head_ctor=encoders.MVNormalDiagParamHead,
        conv_feature_dim=50,  # deterministic feature used by actor/critic/ceb
        ceb_feature_dim=50,
        ceb_action_condition=True,
        ceb_backward_encode_rewards=True,
        initial_feature_step=0,
        feature_lr=3e-4,
        feature_lr_schedule=None,
        ceb_beta=0.01,
        ceb_beta_schedule=None,
        ceb_generative_ratio=0.0,
        ceb_generative_items=None,
        feature_grad_clip=None,
        enc_ema_tau=0.05,  # if enc_ema_tau=None, ceb also learns backend encoder
        use_critic_grad=True,
        # Params for SAC
        actor_kernel_init='glorot_uniform',
        normal_proj_net=sac_agent.sac_normal_projection_net,
        critic_kernel_init='glorot_uniform',
        critic_last_kernel_init='glorot_uniform',
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        collect_every=1,
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        batch_size=256,
        actor_learning_rate=3e-4,
        actor_lr_schedule=None,
        critic_learning_rate=3e-4,
        critic_lr_schedule=None,
        alpha_learning_rate=3e-4,
        alpha_lr_schedule=None,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        drivers_in_graph=True,
        # Params for eval
        num_eval_episodes=10,
        eval_env_interval=5000,  # number of env steps
        greedy_eval_policy=True,
        train_next_frame_decoder=False,
        # Params for summaries and logging
        baseline_log_fn=None,
        checkpoint_env_interval=100000,  # number of env steps
        log_env_interval=1000,  # number of env steps
        summary_interval=1000,
        image_summary_interval=0,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """train and eval for PI-SAC."""
    if random_seed is not None:
        tf.compat.v1.set_random_seed(random_seed)
        np.random.seed(random_seed)

    # Load baseline logs and write to tensorboard
    if baseline_log_fn is not None:
        baseline_log_fn(root_dir, domain_name, task_name, action_repeat)

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    # Set iterations and intervals to be computed relative to the number of
    # environment steps rather than the number of gradient steps.
    num_iterations = (
        num_env_steps * collect_every // collect_steps_per_iteration +
        (initial_feature_step))
    checkpoint_interval = (checkpoint_env_interval * collect_every //
                           collect_steps_per_iteration)
    eval_interval = (eval_env_interval * collect_every //
                     collect_steps_per_iteration)
    log_interval = (log_env_interval * collect_every //
                    collect_steps_per_iteration)
    logging.info('num_env_steps = %d (env steps)', num_env_steps)
    logging.info('initial_feature_step = %d (gradient steps)',
                 initial_feature_step)
    logging.info('num_iterations = %d (gradient steps)', num_iterations)
    logging.info('checkpoint interval (env steps) = %d',
                 checkpoint_env_interval)
    logging.info('checkpoint interval (gradient steps) = %d',
                 checkpoint_interval)
    logging.info('eval interval (env steps) = %d', eval_env_interval)
    logging.info('eval interval (gradient steps) = %d', eval_interval)
    logging.info('log interval (env steps) = %d', log_env_interval)
    logging.info('log interval (gradient steps) = %d', log_interval)

    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_histograms = [
        pisac_metric_utils.ReturnHistogram(buffer_size=num_eval_episodes),
    ]

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        pisac_metric_utils.ReturnStddevMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    # create training environment
    render_configs = {
        'height': frame_shape[0],
        'width': frame_shape[1],
        'camera_id': dict(quadruped=2).get(domain_name, 0),
    }

    tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(domain_name, task_name, render_configs, frame_stack,
                    action_repeat))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(domain_name, task_name, render_configs, frame_stack,
                    action_repeat))

    # Define global step
    g_step = common.create_variable('g_step')

    # Spec
    ims_shape = frame_shape[:2] + (frame_shape[2] * frame_stack, )
    ims_spec = tf.TensorSpec(shape=ims_shape, dtype=tf.uint8)
    conv_feature_spec = tf.TensorSpec(shape=(conv_feature_dim, ),
                                      dtype=tf.float32)
    action_spec = tf_env.action_spec()

    # Forward encoder
    e_enc = e_ctor(ims_spec, output_dim=conv_feature_dim, name='e')
    e_enc_t = e_ctor(ims_spec, output_dim=conv_feature_dim, name='e_t')
    e_enc.create_variables()
    e_enc_t.create_variables()
    common.soft_variables_update(e_enc.variables,
                                 e_enc_t.variables,
                                 tau=1.0,
                                 tau_non_trainable=1.0)

    # Forward encoder head
    if e_head_ctor is None:
        e_head = None
    else:
        stacked_action_spec = tensor_spec.BoundedTensorSpec(
            action_spec.shape[:-1] + (action_spec.shape[-1] * frame_stack),
            action_spec.dtype,
            action_spec.minimum.tolist() * frame_stack,
            action_spec.maximum.tolist() * frame_stack, action_spec.name)
        e_head_spec = [conv_feature_spec, stacked_action_spec
                       ] if ceb_action_condition else conv_feature_spec
        e_head = e_head_ctor(e_head_spec,
                             output_dim=ceb_feature_dim,
                             name='e_head')
        e_head.create_variables()

    # Backward encoder
    b_enc = b_ctor(ims_spec, output_dim=conv_feature_dim, name='b')
    b_enc.create_variables()

    # Backward encoder head
    if b_head_ctor is None:
        b_head = None
    else:
        stacked_reward_spec = tf.TensorSpec(shape=(frame_stack, ),
                                            dtype=tf.float32)
        b_head_spec = [conv_feature_spec, stacked_reward_spec
                       ] if ceb_backward_encode_rewards else conv_feature_spec
        b_head = b_head_ctor(b_head_spec,
                             output_dim=ceb_feature_dim,
                             name='b_head')
        b_head.create_variables()

    # future decoder for generative formulation
    future_deconv = None
    future_reward_mlp = None
    y_decoders = None
    if ceb_generative_ratio > 0.0:
        future_deconv = utils.SimpleDeconv(conv_feature_spec,
                                           output_tensor_spec=ims_spec)
        future_deconv.create_variables()

        future_reward_mlp = utils.MLP(conv_feature_spec,
                                      hidden_dims=(ceb_feature_dim,
                                                   ceb_feature_dim // 2,
                                                   frame_stack))
        future_reward_mlp.create_variables()

        y_decoders = [future_deconv, future_reward_mlp]

    m_vars = e_enc.trainable_variables
    if enc_ema_tau is None:
        m_vars += b_enc.trainable_variables
    else:  # do not train b_enc
        common.soft_variables_update(e_enc.variables,
                                     b_enc.variables,
                                     tau=1.0,
                                     tau_non_trainable=1.0)

    if e_head_ctor is not None:
        m_vars += e_head.trainable_variables
    if b_head_ctor is not None:
        m_vars += b_head.trainable_variables
    if ceb_generative_ratio > 0.0:
        m_vars += future_deconv.trainable_variables
        m_vars += future_reward_mlp.trainable_variables

    feature_lr_fn = schedule_utils.get_schedule_fn(base=feature_lr,
                                                   sched=feature_lr_schedule,
                                                   step=g_step)
    m_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=feature_lr_fn)

    # CEB beta schedule, e.q. 'berp@0:1.0:1000_10000:0.3:0'
    beta_fn = schedule_utils.get_schedule_fn(base=ceb_beta,
                                             sched=ceb_beta_schedule,
                                             step=g_step)

    def img_pred_summary_fn(obs, pred):
        utils.replay_summary('y0',
                             g_step,
                             reshape=True,
                             frame_stack=frame_stack,
                             image_summary_interval=image_summary_interval)(
                                 obs, None)
        utils.replay_summary('y0_pred',
                             g_step,
                             reshape=True,
                             frame_stack=frame_stack,
                             image_summary_interval=image_summary_interval)(
                                 pred, None)
        utils.replay_summary('y0_pred_diff',
                             g_step,
                             reshape=True,
                             frame_stack=frame_stack,
                             image_summary_interval=image_summary_interval)(
                                 ((obs - pred) / 2.0 + 0.5), None)

    ceb = ceb_task.CEB(beta_fn=beta_fn,
                       generative_ratio=ceb_generative_ratio,
                       generative_items=ceb_generative_items,
                       step_counter=g_step,
                       img_pred_summary_fn=img_pred_summary_fn)
    m_ceb = ceb_task.CEBTask(
        ceb,
        e_enc,
        b_enc,
        forward_head=e_head,
        backward_head=b_head,
        y_decoders=y_decoders,
        learn_backward_enc=(enc_ema_tau is None),
        action_condition=ceb_action_condition,
        backward_encode_rewards=ceb_backward_encode_rewards,
        optimizer=m_optimizer,
        grad_clip=feature_grad_clip,
        global_step=g_step)

    if train_next_frame_decoder:
        ns_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
        next_frame_deconv = utils.SimpleDeconv(conv_feature_spec,
                                               output_tensor_spec=ims_spec)
        next_frame_decoder = utils.PixelDecoder(
            next_frame_deconv,
            optimizer=ns_optimizer,
            step_counter=g_step,
            image_summary_interval=image_summary_interval,
            frame_stack=frame_stack)
        next_frame_deconv.create_variables()

    # Agent training
    actor_lr_fn = schedule_utils.get_schedule_fn(base=actor_learning_rate,
                                                 sched=actor_lr_schedule,
                                                 step=g_step)
    critic_lr_fn = schedule_utils.get_schedule_fn(base=critic_learning_rate,
                                                  sched=critic_lr_schedule,
                                                  step=g_step)
    alpha_lr_fn = schedule_utils.get_schedule_fn(base=alpha_learning_rate,
                                                 sched=alpha_lr_schedule,
                                                 step=g_step)

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        conv_feature_spec,
        action_spec,
        kernel_initializer=actor_kernel_init,
        fc_layer_params=actor_fc_layers,
        activation_fn=tf.keras.activations.relu,
        continuous_projection_net=normal_proj_net)

    critic_net = critic_network.CriticNetwork(
        (conv_feature_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        activation_fn=tf.nn.relu,
        kernel_initializer=critic_kernel_init,
        last_kernel_initializer=critic_last_kernel_init)

    tf_agent = sac_agent.SacAgent(
        ts.time_step_spec(observation_spec=conv_feature_spec),
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_lr_fn),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_lr_fn),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_lr_fn),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=g_step)
    tf_agent.initialize()

    env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
    average_return = tf_metrics.AverageReturnMetric(
        prefix='Train',
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(prefix='Train'), env_steps, average_return,
        tf_metrics.AverageEpisodeLengthMetric(prefix='Train',
                                              buffer_size=num_eval_episodes,
                                              batch_size=tf_env.batch_size),
        tf_metrics.AverageReturnMetric(name='LatestReturn',
                                       prefix='Train',
                                       buffer_size=1,
                                       batch_size=tf_env.batch_size)
    ]

    # Collect and eval policies
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), action_spec)

    eval_policy = tf_agent.policy
    if greedy_eval_policy:
        eval_policy = greedy_policy.GreedyPolicy(eval_policy)

    def obs_to_feature(observation):
        """Encode the pixel observation with `e_enc` and block gradients.

        Args:
            observation: mapping that holds the raw frames under `'pixels'`.

        Returns:
            The first of the two outputs of `e_enc` (called with
            `training=False`), wrapped in `tf.stop_gradient`.
        """
        pixels = observation['pixels']
        encoded, _ = e_enc(pixels, training=False)
        return tf.stop_gradient(encoded)

    eval_policy = FeaturePolicy(policy=eval_policy,
                                time_step_spec=tf_env.time_step_spec(),
                                obs_to_feature_fn=obs_to_feature)

    collect_policy = FeaturePolicy(policy=tf_agent.collect_policy,
                                   time_step_spec=tf_env.time_step_spec(),
                                   obs_to_feature_fn=obs_to_feature)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=collect_policy.trajectory_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    # Checkpoints
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        actor_net=actor_net,
        critic_net=critic_net,
        global_step=g_step,
        metrics=tfa_metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    train_checkpointer.initialize_or_restore()

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        root_dir, 'policy'),
                                              policy=eval_policy,
                                              global_step=g_step)
    policy_checkpointer.initialize_or_restore()

    rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        root_dir, 'replay_buffer'),
                                          max_to_keep=1,
                                          replay_buffer=replay_buffer,
                                          global_step=g_step)
    rb_checkpointer.initialize_or_restore()

    if learn_ceb:
        d = dict()
        if future_deconv is not None:
            d.update(future_deconv=future_deconv)
        if future_reward_mlp is not None:
            d.update(future_reward_mlp=future_reward_mlp)
        model_ckpt = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'model'),
                                         forward_encoder=e_enc,
                                         forward_encoder_target=e_enc_t,
                                         forward_head=e_head,
                                         backward_encoder=b_enc,
                                         backward_head=b_head,
                                         global_step=g_step,
                                         **d)
    else:
        model_ckpt = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'model'),
                                         forward_encoder=e_enc,
                                         forward_encoder_target=e_enc_t,
                                         global_step=g_step)
    model_ckpt.initialize_or_restore()

    if train_next_frame_decoder:
        next_frame_decoder_ckpt = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'next_frame_decoder'),
            next_frame_decoder=next_frame_decoder,
            next_frame_deconv=next_frame_deconv,
            global_step=g_step)
        next_frame_decoder_ckpt.initialize_or_restore()

    if use_tf_functions and not drivers_in_graph:
        collect_policy.action = common.function(collect_policy.action)

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions and drivers_in_graph:
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)

    # Collect initial replay data.
    if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
        qj(initial_collect_steps,
           'Initializing replay buffer by collecting random experience',
           tic=1)
        initial_collect_driver.run()
        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=env_steps.result())
        qj(s='Done initializing replay buffer', toc=1)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    time_acc = 0
    env_steps_before = env_steps.result().numpy()

    paddings = tf.constant([[4, 4], [4, 4], [0, 0]])

    def random_shifting(traj, meta):
        """Attach randomly shifted crops of selected frames to a trajectory.

        The entries at indices 0, 1, `frame_stack` and `frame_stack + 1` of
        `traj.observation['pixels']` are padded symmetrically with `paddings`
        and then randomly cropped to `ims_shape`. The first two entries are
        cropped twice (two separate `random_crop` calls each).

        Args:
            traj: sampled trajectory with a `'pixels'` observation entry.
            meta: sampling metadata, passed through unchanged.

        Returns:
            `((traj, (x0, x1, x0a, x1a, y0, y1)), meta)`.
        """
        pixels = traj.observation['pixels']

        def pad(index):
            return tf.pad(pixels[index], paddings, 'SYMMETRIC')

        def crop(frame):
            return tf.image.random_crop(frame, ims_shape)

        x0_padded = pad(0)
        x1_padded = pad(1)
        y0_padded = pad(frame_stack)
        y1_padded = pad(frame_stack + 1)

        # Crop order is kept: the extra (augmented) crops are taken first,
        # followed by the primary crops of x0, x1, y0 and y1.
        x0a = crop(x0_padded)
        x1a = crop(x1_padded)
        crops = (crop(x0_padded), crop(x1_padded), x0a, x1a,
                 crop(y0_padded), crop(y1_padded))
        return (traj, crops), meta

    # Dataset generates trajectories with shape [B, T, ...]
    num_steps = frame_stack + 2
    with tf.device('/cpu:0'):
        if image_aug_type == 'random_shifting':
            dataset = replay_buffer.as_dataset(
                sample_batch_size=batch_size,
                num_steps=num_steps).unbatch().filter(
                    utils.filter_invalid_transition).map(
                        random_shifting,
                        num_parallel_calls=3).batch(batch_size).map(
                            utils.replay_summary(
                                'replay/filtered',
                                order_frame_stack=True,
                                frame_stack=frame_stack,
                                image_summary_interval=image_summary_interval,
                                has_augmentations=True))
        elif image_aug_type is None:
            dataset = replay_buffer.as_dataset(
                sample_batch_size=batch_size,
                num_steps=num_steps).unbatch().filter(
                    utils.filter_invalid_transition).batch(batch_size).map(
                        utils.replay_summary(
                            'replay/filtered',
                            order_frame_stack=True,
                            frame_stack=frame_stack,
                            image_summary_interval=image_summary_interval,
                            has_augmentations=False))
        else:
            raise NotImplementedError
    iterator_nstep = iter(dataset)

    def model_train_step(experience):
        """Run one CEB model update on a sampled batch.

        Args:
            experience: a dataset element. With
                `image_aug_type == 'random_shifting'` it is a pair
                `(trajectory, cropped_frames)`; with `image_aug_type is None`
                it is passed to `utils.split_xy` as is.

        Raises:
            NotImplementedError: for any other `image_aug_type`.
        """
        if image_aug_type == 'random_shifting':
            experience, crops = experience
            x0, x1, _, _, y0, y1 = crops
            r0, r1, a0, a1 = utils.split_xy(experience,
                                            frame_stack,
                                            rewards_n_actions_only=True)
            # Insert a singleton axis at position 1 of every cropped frame.
            x0, x1, y0, y1 = (frame[:, None, ...]
                              for frame in (x0, x1, y0, y1))
        elif image_aug_type is None:
            x0, x1, y0, y1, r0, r1, a0, a1 = utils.split_xy(
                experience, frame_stack, rewards_n_actions_only=False)
        else:
            raise NotImplementedError

        # Flatten stacked actions: keep the two leading dims of a0, merge the
        # rest.
        a0_dims = a0.shape.as_list()
        flat_shape = [a0_dims[0], a0_dims[1], -1]
        a0 = tf.reshape(a0, flat_shape)
        a1 = tf.reshape(a1, flat_shape)

        if image_summary_interval > 0:
            tagged_frames = (('ceb/x0', x0), ('ceb/x1', x1), ('ceb/y0', y0),
                             ('ceb/y1', y1))
            for tag, frames in tagged_frames:
                summarize = utils.replay_summary(
                    tag,
                    g_step,
                    reshape=True,
                    frame_stack=frame_stack,
                    image_summary_interval=image_summary_interval)
                summarize(frames, None)

        # Only the third output of the CEB update is used below.
        _, _, zx0 = m_ceb.train(x0, a0, y0, y1, r0, r1, m_vars)
        if train_next_frame_decoder:
            # zx0: [B, 1, Z]
            latent = tf.squeeze(zx0, axis=1)
            # y0: [B, 1, H, W, Cxframe_stack]
            next_obs = tf.cast(tf.squeeze(y0, axis=1), tf.float32) / 255.0
            next_frame_decoder.train(next_obs, tf.stop_gradient(latent))

        if enc_ema_tau is not None:
            # Soft-update b_enc's variables from e_enc's variables.
            common.soft_variables_update(e_enc.variables,
                                         b_enc.variables,
                                         tau=enc_ema_tau,
                                         tau_non_trainable=enc_ema_tau)

    def agent_train_step(experience):
        """Run one agent update (`tf_agent.train_pix`) on a sampled batch.

        Args:
            experience: a dataset element. With
                `image_aug_type == 'random_shifting'` it is a pair
                `(trajectory, cropped_frames)` where `cropped_frames` is
                `(x0, x1, x0a, x1a, y0, y1)`; with `image_aug_type is None`
                it is the trajectory itself.

        Raises:
            NotImplementedError: if `image_aug_type` is neither
                `'random_shifting'` nor `None`.
            ValueError: if `use_augmented_q` is set while `image_aug_type`
                is `None`, since no augmented frames exist in that case.
        """
        # preprocess experience
        if image_aug_type == 'random_shifting':
            experience, cropped_frames = experience
            # y0 / y1 are not used by the agent update.
            x0, x1, x0a, x1a, _, _ = cropped_frames
        elif image_aug_type is None:
            if use_augmented_q:
                # Previously this surfaced later as an UnboundLocalError on
                # x0a; fail early with an actionable message instead.
                raise ValueError(
                    "use_augmented_q requires image_aug_type == "
                    "'random_shifting'; no augmented frames are available "
                    "when image_aug_type is None.")
            x0a = x1a = None
        else:
            raise NotImplementedError

        # Slice every tensor along axis 1 up to index 2 before converting
        # the experience into transitions.
        experience = tf.nest.map_structure(
            lambda t: composite.slice_to(t, axis=1, end=2), experience)
        time_steps, actions, next_time_steps = (
            tf_agent.experience_to_transitions(experience))

        if image_aug_type is None:
            # Without augmentation the raw pixels feed the agent directly.
            x0 = time_steps.observation['pixels']
            x1 = next_time_steps.observation['pixels']

        tf_agent.train_pix(time_steps,
                           actions,
                           next_time_steps,
                           x0,
                           x1,
                           x0a=x0a if use_augmented_q else None,
                           x1a=x1a if use_augmented_q else None,
                           e_enc=e_enc,
                           e_enc_t=e_enc_t,
                           q_aug=use_augmented_q,
                           use_critic_grad=use_critic_grad)

    def checkpoint(step):
        """Save all checkpointers at global step `step`.

        The next-frame-decoder checkpointer is saved last, and only when
        `train_next_frame_decoder` is set.
        """
        checkpointers = [
            rb_checkpointer,
            train_checkpointer,
            policy_checkpointer,
            model_ckpt,
        ]
        if train_next_frame_decoder:
            checkpointers.append(next_frame_decoder_ckpt)
        for checkpointer in checkpointers:
            checkpointer.save(global_step=step)

    def evaluate():
        """Evaluate `eval_policy` on `eval_tf_env` and log the eval metrics.

        Summaries are always recorded inside this function, and the optional
        `eval_metrics_callback` receives the results together with the
        current `env_steps.result()`.
        """
        # Override outer record_if that may be out of sync with respect to the
        # env_steps.result() value used for the summary step.
        with tf.compat.v2.summary.record_if(True):
            qj(g_step.numpy(), 'Starting eval at step', tic=1)
            eval_options = dict(
                histograms=eval_histograms,
                num_episodes=num_eval_episodes,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
                use_function=drivers_in_graph,
            )
            results = pisac_metric_utils.eager_compute(
                eval_metrics, eval_tf_env, eval_policy, **eval_options)
            callback = eval_metrics_callback
            if callback is not None:
                callback(results, env_steps.result())
            tfa_metric_utils.log_metrics(eval_metrics)
            qj(s='Finished eval', toc=1)

    def update_target():
        """Soft-update the target encoder and both target critics.

        Each target network's variables are updated from its source network
        with `tf_agent.target_update_tau`, used for trainable and
        non-trainable variables alike.
        """
        # pylint: disable=protected-access
        source_target_pairs = (
            (e_enc, e_enc_t),
            (tf_agent._critic_network_1, tf_agent._target_critic_network_1),
            (tf_agent._critic_network_2, tf_agent._target_critic_network_2),
        )
        # pylint: enable=protected-access
        for source_net, target_net in source_target_pairs:
            tau = tf_agent.target_update_tau
            common.soft_variables_update(source_net.variables,
                                         target_net.variables,
                                         tau=tau,
                                         tau_non_trainable=tau)

    if use_tf_functions:
        if learn_ceb:
            m_ceb.train = common.function(m_ceb.train)
            model_train_step = common.function(model_train_step)
        agent_train_step = common.function(agent_train_step)
        tf_agent.train_pix = common.function(tf_agent.train_pix)
        update_target = common.function(update_target)
        if train_next_frame_decoder:
            next_frame_decoder.train = common.function(
                next_frame_decoder.train)

    if not learn_ceb and initial_feature_step > 0:
        raise ValueError('Not learning CEB but initial_feature_step > 0')

    with tf.summary.record_if(
            lambda: tf.math.equal(g_step % summary_interval, 0)):
        if learn_ceb and g_step.numpy() < initial_feature_step:
            qj(initial_feature_step, 'Pretraining CEB...', tic=1)
            for _ in range(g_step.numpy(), initial_feature_step):
                with tf.name_scope('LearningRates'):
                    tf.summary.scalar(name='CEB learning rate',
                                      data=feature_lr_fn(),
                                      step=g_step)
                experience, _ = next(iterator_nstep)
                model_train_step(experience)
                g_step.assign_add(1)
            qj(s='Done pretraining CEB.', toc=1)

    first_step = True
    for _ in range(g_step.numpy(), num_iterations):
        g_step_val = g_step.numpy()
        start_time = time.time()

        with tf.summary.record_if(
                lambda: tf.math.equal(g_step % summary_interval, 0)):

            with tf.name_scope('LearningRates'):
                tf.summary.scalar(name='Actor learning rate',
                                  data=actor_lr_fn(),
                                  step=g_step)
                tf.summary.scalar(name='Critic learning rate',
                                  data=critic_lr_fn(),
                                  step=g_step)
                tf.summary.scalar(name='Alpha learning rate',
                                  data=alpha_lr_fn(),
                                  step=g_step)
                if learn_ceb:
                    tf.summary.scalar(name='CEB learning rate',
                                      data=feature_lr_fn(),
                                      step=g_step)

            with tf.name_scope('Train'):
                tf.summary.scalar(name='StepsVsEnvironmentSteps',
                                  data=env_steps.result(),
                                  step=g_step)
                tf.summary.scalar(name='StepsVsAverageReturn',
                                  data=average_return.result(),
                                  step=g_step)

            if g_step_val % collect_every == 0:
                time_step, policy_state = collect_driver.run(
                    time_step=time_step,
                    policy_state=policy_state,
                )

            experience, _ = next(iterator_nstep)
            agent_train_step(experience)
            if (g_step_val -
                    initial_feature_step) % tf_agent.target_update_period == 0:
                update_target()
            if learn_ceb:
                model_train_step(experience)
            time_acc += time.time() - start_time

        # Increment global step counter.
        g_step.assign_add(1)
        g_step_val = g_step.numpy()

        if (g_step_val - initial_feature_step) % log_interval == 0:
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())
            logging.info('env steps = %d, average return = %f',
                         env_steps.result(), average_return.result())
            env_steps_per_sec = (env_steps.result().numpy() -
                                 env_steps_before) / time_acc
            logging.info('%.3f env steps/sec', env_steps_per_sec)
            tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                        data=env_steps_per_sec,
                                        step=env_steps.result())
            time_acc = 0
            env_steps_before = env_steps.result().numpy()

        if (g_step_val - initial_feature_step) % eval_interval == 0:
            eval_start_time = time.time()
            evaluate()
            logging.info('eval time %.3f sec', time.time() - eval_start_time)

        if (g_step_val - initial_feature_step) % checkpoint_interval == 0:
            checkpoint(g_step_val)

        # Write gin config to Tensorboard
        if first_step:
            summ = utils.Summ(0, root_dir)
            conf = gin.operative_config_str()
            conf = '    ' + conf.replace('\n', '\n    ')
            summ.text('gin/config', conf)
            summ.flush()
            first_step = False

    # Final checkpoint.
    checkpoint(g_step.numpy())

    # Final evaluation.
    evaluate()
Exemplo n.º 18
0
    def model_train_step(experience):
        """Perform a single CEB training step on one sampled batch.

        With `image_aug_type == 'random_shifting'` the input is a pair of
        `(trajectory, cropped_frames)`; with `image_aug_type is None` the
        frames come from `utils.split_xy`. Any other value raises
        `NotImplementedError`.
        """
        if image_aug_type == 'random_shifting':
            experience, cropped = experience
            x0, x1, _, _, y0, y1 = cropped
            r0, r1, a0, a1 = utils.split_xy(experience,
                                            frame_stack,
                                            rewards_n_actions_only=True)

            def add_axis(frames):
                # Adds a singleton axis at position 1.
                return frames[:, None, ...]

            x0 = add_axis(x0)
            x1 = add_axis(x1)
            y0 = add_axis(y0)
            y1 = add_axis(y1)
        elif image_aug_type is None:
            x0, x1, y0, y1, r0, r1, a0, a1 = utils.split_xy(
                experience, frame_stack, rewards_n_actions_only=False)
        else:
            raise NotImplementedError

        # Flatten stacked actions
        dims = a0.shape.as_list()
        target_shape = [dims[0], dims[1], -1]
        a0 = tf.reshape(a0, target_shape)
        a1 = tf.reshape(a1, target_shape)

        if image_summary_interval > 0:
            # One image summary per frame tensor, in x0, x1, y0, y1 order.
            for tag, frames in zip(('ceb/x0', 'ceb/x1', 'ceb/y0', 'ceb/y1'),
                                   (x0, x1, y0, y1)):
                utils.replay_summary(
                    tag,
                    g_step,
                    reshape=True,
                    frame_stack=frame_stack,
                    image_summary_interval=image_summary_interval)(frames,
                                                                   None)

        train_outputs = m_ceb.train(x0, a0, y0, y1, r0, r1, m_vars)
        # The train call yields three values; only the last is needed here.
        _, _, zx0 = train_outputs
        if train_next_frame_decoder:
            # zx0: [B, 1, Z]
            zx0 = tf.squeeze(zx0, axis=1)
            # y0: [B, 1, H, W, Cxframe_stack]
            next_obs = tf.cast(tf.squeeze(y0, axis=1), tf.float32) / 255.0
            next_frame_decoder.train(next_obs, tf.stop_gradient(zx0))

        if enc_ema_tau is not None:
            # Soft-update the backward encoder from the forward encoder.
            common.soft_variables_update(e_enc.variables,
                                         b_enc.variables,
                                         tau=enc_ema_tau,
                                         tau_non_trainable=enc_ema_tau)
Exemplo n.º 19
0
    def __init__(self,
                 action_spec,
                 actor_network: Network,
                 critic_network: Network,
                 ou_stddev=0.2,
                 ou_damping=0.15,
                 critic_loss=None,
                 target_update_tau=0.05,
                 target_update_period=1,
                 dqda_clipping=None,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 gradient_clipping=None,
                 train_step_counter=None,
                 debug_summaries=False,
                 name="DdpgAlgorithm"):
        """Create the algorithm, its target networks and the target updater.

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            actor_network (Network):  The network will be called with
                call(observation, step_type).
            critic_network (Network): The network will be called with
                call(observation, action, step_type).
            ou_stddev (float): Standard deviation for the Ornstein-Uhlenbeck
                (OU) noise added in the default collect policy.
            ou_damping (float): Damping factor for the OU noise added in the
                default collect policy.
            critic_loss (None|OneStepTDLoss): an object for calculating critic
                loss. If None, a default OneStepTDLoss will be used.
            target_update_tau (float): Factor for soft update of the target
                networks.
            target_update_period (int): Period for soft update of the target
                networks.
            dqda_clipping (float): when computing the actor loss, clips the
                gradient dqda element-wise between [-dqda_clipping, dqda_clipping].
                Does not perform clipping if dqda_clipping == 0.
            actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
            critic_optimizer (tf.optimizers.Optimizer): The optimizer for
                critic.
            gradient_clipping (float): Norm length to clip gradients.
            train_step_counter (tf.Variable): An optional counter to increment
                every time a new iteration is started. If None, it will use
                tf.summary.experimental.get_step(). If this is still None, a
                counter will be created.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this algorithm.
        """
        # Training state is assembled from the state specs of the actor and
        # critic networks, including slots for their target copies.
        train_state_spec = DdpgState(
            actor=DdpgActorState(actor=actor_network.state_spec,
                                 critic=critic_network.state_spec),
            critic=DdpgCriticState(critic=critic_network.state_spec,
                                   target_actor=actor_network.state_spec,
                                   target_critic=critic_network.state_spec))

        # Optimizers and trainable-variable getters are passed as parallel
        # lists: index 0 is the actor, index 1 is the critic.
        super().__init__(action_spec,
                         train_state_spec=train_state_spec,
                         action_distribution_spec=action_spec,
                         optimizer=[actor_optimizer, critic_optimizer],
                         get_trainable_variables_func=[
                             lambda: actor_network.trainable_variables,
                             lambda: critic_network.trainable_variables
                         ],
                         gradient_clipping=gradient_clipping,
                         train_step_counter=train_step_counter,
                         debug_summaries=debug_summaries,
                         name=name)

        self._actor_network = actor_network
        self._critic_network = critic_network
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer

        # Target networks are created as copies of the online networks.
        self._target_actor_network = actor_network.copy(
            name='target_actor_network')
        self._target_critic_network = critic_network.copy(
            name='target_critic_network')

        self._ou_stddev = ou_stddev
        self._ou_damping = ou_damping

        if critic_loss is None:
            critic_loss = OneStepTDLoss(debug_summaries=debug_summaries)
        self._critic_loss = critic_loss

        self._ou_process = self._create_ou_process(ou_stddev, ou_damping)

        # Updater pairing each online network with its target network.
        self._update_target = common.get_target_updater(
            models=[self._actor_network, self._critic_network],
            target_models=[
                self._target_actor_network, self._target_critic_network
            ],
            tau=target_update_tau,
            period=target_update_period)

        self._dqda_clipping = dqda_clipping

        # With tau=1.0 the soft update assigns the online variables to the
        # target variables, so both start from the same values.
        tfa_common.soft_variables_update(self._critic_network.variables,
                                         self._target_critic_network.variables,
                                         tau=1.0)
        tfa_common.soft_variables_update(self._actor_network.variables,
                                         self._target_actor_network.variables,
                                         tau=1.0)
Exemplo n.º 20
0
 def update():
   """Soft-update the target Q-network variables from the Q-network."""
   source_vars = self._q_network.variables
   target_vars = self._target_q_network.variables
   return common_utils.soft_variables_update(source_vars, target_vars, tau)
Exemplo n.º 21
0
 def update():
   """Soft-update the target Q-network; non-trainable variables use tau 1.0."""
   online_vars = self._q_network.variables
   target_vars = self._target_q_network.variables
   return common.soft_variables_update(
       online_vars, target_vars, tau, tau_non_trainable=1.0)
Exemplo n.º 22
0
 def update():
     """Soft-update TargetQMIXNet from QMIXNet (non-trainable tau is 1.0)."""
     mixer_vars = self.QMIXNet.variables
     target_mixer_vars = self.TargetQMIXNet.variables
     return common.soft_variables_update(
         mixer_vars, target_mixer_vars, tau, tau_non_trainable=1.0)
Exemplo n.º 23
0
 def _initialize(self):
   """Update the target Q-network from the Q-network with tau=1.0."""
   online_vars = self._q_network.variables
   target_vars = self._target_q_network.variables
   common.soft_variables_update(online_vars, target_vars, tau=1.0)
Exemplo n.º 24
0
 def update():
     """Soft-update the target network from the value network.

     Non-trainable variables are updated with tau 1.0.
     """
     value_vars = self._value_network.variables
     target_vars = self._target_network.variables
     return tfagents_common.soft_variables_update(
         value_vars, target_vars, tau, tau_non_trainable=1.0)
Exemplo n.º 25
0
 def _initialize(self):
     """Update the target network from the value network with tau=1.0."""
     value_vars = self._value_network.variables
     target_vars = self._target_network.variables
     tfagents_common.soft_variables_update(value_vars, target_vars, tau=1.0)
Exemplo n.º 26
0
 def update():
   """Soft-update the target safety critic from the safety critic."""
   critic_vars = self._safety_critic_network.variables
   target_critic_vars = self._target_safety_critic_network.variables
   return common.soft_variables_update(critic_vars, target_critic_vars, tau)
   return critic_update