Exemple #1
0
    def __init__(
            self,
            state_shape,
            units=(32, 32),
            lr=0.001,
            enable_sn=False,
            name="GAIfO",
            **kwargs):
        """
        Initialize GAIfO

        Args:
            state_shape (iterable of int): Shape of the state, forwarded to
                the ``Discriminator``.
            units (iterable of int): Unit counts forwarded to the
                ``Discriminator``. The default is ``(32, 32)``
            lr (float): Learning rate of the Adam optimizer. The default is
                ``0.001``
            enable_sn (bool): Whether to enable Spectral Normalization in the
                ``Discriminator``. The default is ``False``
            name (str): The default is ``"GAIfO"``
            **kwargs: Forwarded unchanged to ``IRLPolicy.__init__``.
        """
        # NOTE(review): n_training is fixed to 1 here; presumably the number of
        # discriminator updates per training call — confirm in IRLPolicy.
        IRLPolicy.__init__(self, name=name, n_training=1, **kwargs)
        self.disc = Discriminator(
            state_shape=state_shape,
            units=units, enable_sn=enable_sn)
        # beta_1 is lowered from the Keras default (0.9) to 0.5; presumably
        # for adversarial-training stability — confirm with the authors.
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=0.5)
Exemple #2
0
 def __init__(
         self,
         state_shape,
         action_dim,
         units=(32, 32),
         n_latent_unit=32,
         lr=5e-5,
         kl_target=0.5,
         reg_param=0.,
         enable_sn=False,
         enable_gp=False,
         name="VAIL",
         **kwargs):
     """
     Initialize VAIL

     Args:
         state_shape (iterable of int): Shape of the state, forwarded to the
             ``Discriminator``.
         action_dim (int): Action dimension, forwarded to the
             ``Discriminator``.
         units (iterable of int): Unit counts forwarded to the
             ``Discriminator``. The default is ``(32, 32)``; it is a tuple
             so the default object cannot be mutated and shared between
             instances.
         n_latent_unit (int): Number of latent units forwarded to the
             ``Discriminator``. The default is ``32``
         lr (float): Learning rate of the Adam optimizer. The default is
             ``5e-5``
         kl_target (float): Stored as ``self._kl_target``. The default is
             ``0.5``
         reg_param (float): Initial value of the regularization parameter,
             stored in a ``tf.Variable`` (``self._reg_param``). The default
             is ``0``
         enable_sn (bool): If true, add spectral normalization in Dense layer
         enable_gp (bool): If true, add gradient penalty to loss function
         name (str): The default is ``"VAIL"``
         **kwargs: Forwarded unchanged to ``IRLPolicy.__init__``.
     """
     IRLPolicy.__init__(
         self, name=name, n_training=10, **kwargs)
     self.disc = Discriminator(
         state_shape=state_shape, action_dim=action_dim,
         units=units, n_latent_unit=n_latent_unit,
         enable_sn=enable_sn)
     self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
     self._kl_target = kl_target
     # Held in a tf.Variable (mutable), presumably so training can adjust it
     # towards kl_target — confirm in the train step.
     self._reg_param = tf.Variable(reg_param, dtype=tf.float32)
     # NOTE(review): presumably the step size for adjusting _reg_param — confirm.
     self._step_reg_param = tf.constant(1e-5, dtype=tf.float32)
     self._enable_gp = enable_gp
Exemple #3
0
 def __init__(self,
              state_shape,
              units=(32, 32),
              lr=0.001,
              enable_sn=False,
              name="GAIfO",
              **kwargs):
     """
     Initialize GAIfO

     Args:
         state_shape (iterable of int): Shape of the state, forwarded to the
             ``Discriminator``.
         units (iterable of int): Unit counts forwarded to the
             ``Discriminator``. The default is ``(32, 32)``; it is a tuple
             so the default object cannot be mutated and shared between
             instances.
         lr (float): Learning rate of the Adam optimizer. The default is
             ``0.001``
         enable_sn (bool): Whether to enable Spectral Normalization in the
             ``Discriminator``. The default is ``False``
         name (str): The default is ``"GAIfO"``
         **kwargs: Forwarded unchanged to ``IRLPolicy.__init__``.
     """
     IRLPolicy.__init__(self, name=name, n_training=1, **kwargs)
     self.disc = Discriminator(state_shape=state_shape,
                               units=units,
                               enable_sn=enable_sn)
     # beta_1 is lowered from the Keras default (0.9) to 0.5; presumably for
     # adversarial-training stability — confirm with the authors.
     self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
Exemple #4
0
    def __init__(self,
                 state_shape,
                 action_dim,
                 units=(32, 32),
                 n_latent_unit=32,
                 lr=5e-5,
                 kl_target=0.5,
                 reg_param=0.,
                 enable_sn=False,
                 enable_gp=False,
                 name="VAIL",
                 **kwargs):
        """
        Initialize VAIL

        Args:
            state_shape (iterable of int): Shape of the state, forwarded to
                the ``Discriminator``.
            action_dim (int): Action dimension, forwarded to the
                ``Discriminator``.
            units (iterable of int): Unit counts forwarded to the
                ``Discriminator``. The default is ``(32, 32)``
            n_latent_unit (int): Number of latent units forwarded to the
                ``Discriminator``. The default is ``32``
            lr (float): Learning rate of the Adam optimizer. The default is
                ``5e-5``
            kl_target (float): Stored as ``self._kl_target``. The default is
                ``0.5``
            reg_param (float): Initial value of the regularization parameter,
                stored in a ``tf.Variable`` (``self._reg_param``). The default
                is ``0``
            enable_sn (bool): Whether to enable Spectral Normalization. The
                default is ``False``
            enable_gp (bool): Whether the loss function includes a gradient
                penalty. The default is ``False``
            name (str): The default is ``"VAIL"``
            **kwargs: Forwarded unchanged to ``IRLPolicy.__init__``.
        """
        IRLPolicy.__init__(self, name=name, n_training=10, **kwargs)
        self.disc = Discriminator(state_shape=state_shape,
                                  action_dim=action_dim,
                                  units=units,
                                  n_latent_unit=n_latent_unit,
                                  enable_sn=enable_sn)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        self._kl_target = kl_target
        # Held in a tf.Variable (mutable), presumably so training can adjust
        # it towards kl_target — confirm in the train step.
        self._reg_param = tf.Variable(reg_param, dtype=tf.float32)
        # NOTE(review): presumably the step size for adjusting _reg_param — confirm.
        self._step_reg_param = tf.constant(1e-5, dtype=tf.float32)
        self._enable_gp = enable_gp
Exemple #5
0
 def get_argument(parser=None):
     """
     Return the policy's command-line parser with ``--enable-sn`` added.

     Args:
         parser: Passed straight through to ``IRLPolicy.get_argument``. The
             default is ``None`` (what the base class does with ``None`` is
             not visible here — presumably it creates a fresh parser).

     Returns:
         The parser object returned by ``IRLPolicy.get_argument``, with an
         extra ``--enable-sn`` flag registered on it.
     """
     base_parser = IRLPolicy.get_argument(parser)
     # store_true: the flag is False unless given on the command line;
     # presumably mapped to the ``enable_sn`` constructor argument.
     base_parser.add_argument('--enable-sn', action='store_true')
     return base_parser