Example #1
    def test_type_kwargs(self):
        """The the special cases involving "type" and "kwargs"
        hyperparameters.
        """
        default_hparams = {"type": "type_name", "kwargs": {"arg1": "argv1"}}

        hparams = {"type": "type_name"}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), default_hparams["kwargs"])

        hparams = {"type": "type_name", "kwargs": {"arg2": "argv2"}}
        hparams_ = HParams(hparams, default_hparams)
        full_kwargs = {}
        full_kwargs.update(default_hparams["kwargs"])
        full_kwargs.update(hparams["kwargs"])
        self.assertEqual(hparams_.kwargs.todict(), full_kwargs)

        hparams = {"kwargs": {"arg2": "argv2"}}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), full_kwargs)

        hparams = {"type": "type_name2"}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), {})

        hparams = {"type": "type_name2", "kwargs": {"arg3": "argv3"}}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), hparams["kwargs"])
Example #2
    def __init__(self, vocab, hparams=None):
        self._hparams = HParams(hparams, self.default_hparams())

        # Initialize embeddings
        init_fn_kwargs = self._hparams.init_fn.kwargs.todict()
        if "shape" in init_fn_kwargs or "size" in init_fn_kwargs:
            raise ValueError("Argument 'shape' or 'size' must not be "
                             "specified. They are inferred automatically.")
        init_fn = utils.get_function(
            self._hparams.init_fn.type,
            ["numpy.random", "numpy", "texar.tf.custom"])

        try:
            self._word_vecs = init_fn(size=[len(vocab), self._hparams.dim],
                                      **init_fn_kwargs)
        except TypeError:
            self._word_vecs = init_fn(shape=[len(vocab), self._hparams.dim],
                                      **init_fn_kwargs)

        # Optionally read embeddings from file
        if self._hparams.file is not None and self._hparams.file != "":
            read_fn = utils.get_function(self._hparams.read_fn, [
                "texar.tf.data.embedding", "texar.tf.data", "texar.tf.custom"
            ])

            self._word_vecs = \
                read_fn(self._hparams.file, vocab, self._word_vecs)
Example #3
    def test_switch_dropout(self):
        """Tests dropout mode.
        """
        emb_dim = 4
        num_units = 64
        hparams = {
            "kwargs": {
                "num_units": num_units
            },
            "num_layers": 2,
            "dropout": {
                "input_keep_prob": 0.8,
            },
        }
        mode = tf.placeholder(tf.string)
        hparams_ = HParams(hparams, layers.default_rnn_cell_hparams())
        cell = layers.get_rnn_cell(hparams_, mode)

        batch_size = 16
        inputs = tf.zeros([batch_size, emb_dim], dtype=tf.float32)
        output, state = cell(inputs,
                             cell.zero_state(batch_size, dtype=tf.float32))
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            output_train, _ = sess.run(
                [output, state], feed_dict={mode: tf.estimator.ModeKeys.TRAIN})
            self.assertEqual(output_train.shape[0], batch_size)
            output_test, _ = sess.run(
                [output, state], feed_dict={mode: tf.estimator.ModeKeys.EVAL})
            self.assertEqual(output_test.shape[0], batch_size)
Example #4
    def test_get_rnn_cell(self):
        """Tests :func:`texar.tf.core.layers.get_rnn_cell`.
        """
        emb_dim = 4
        num_units = 64

        # Given instance
        hparams = {"type": rnn.LSTMCell(num_units)}
        cell = layers.get_rnn_cell(hparams)
        self.assertTrue(isinstance(cell, rnn.LSTMCell))

        # Given class
        hparams = {"type": rnn.LSTMCell, "kwargs": {"num_units": 10}}
        cell = layers.get_rnn_cell(hparams)
        self.assertTrue(isinstance(cell, rnn.LSTMCell))

        # Given string, and complex hyperparameters
        keep_prob_x = tf.placeholder(name='keep_prob',
                                     shape=[],
                                     dtype=tf.float32)
        hparams = {
            "type": "tensorflow.contrib.rnn.GRUCell",
            "kwargs": {
                "num_units": num_units
            },
            "num_layers": 2,
            "dropout": {
                "input_keep_prob": 0.8,
                "state_keep_prob": keep_prob_x,
                "variational_recurrent": True,
                "input_size": [emb_dim, num_units]
            },
            "residual": True,
            "highway": True
        }

        hparams_ = HParams(hparams, layers.default_rnn_cell_hparams())
        cell = layers.get_rnn_cell(hparams_)

        batch_size = 16
        inputs = tf.zeros([batch_size, emb_dim], dtype=tf.float32)
        output, state = cell(inputs,
                             cell.zero_state(batch_size, dtype=tf.float32))
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            feed_dict = {
                keep_prob_x: 1.0,
                context.global_mode(): tf.estimator.ModeKeys.TRAIN
            }
            output_, state_ = sess.run([output, state], feed_dict=feed_dict)

            self.assertEqual(output_.shape[0], batch_size)
            if isinstance(state_, (list, tuple)):
                self.assertEqual(state_[0].shape[0], batch_size)
                self.assertEqual(state_[0].shape[1], hparams_.kwargs.num_units)
            else:
                self.assertEqual(state_.shape[0], batch_size)
                self.assertEqual(state_.shape[1], hparams_.kwargs.num_units)
Example #5
    def __init__(self, hparams=None):
        self._hparams = HParams(hparams, self.default_hparams())
        self._template = tf.make_template(self._hparams.name,
                                          self._build,
                                          create_scope_now_=True)
        self._unique_name = self.variable_scope.name.split("/")[-1]
        self._trainable_variables = []
        self._built = False
Example #6
def get_gradient_clip_fn(hparams=None):
    """Creates a gradient clipping function based on the hyperparameters.

    See the :attr:`gradient_clip` field in
    :meth:`~texar.tf.core.default_optimization_hparams` for all
    hyperparameters and default values.

    The gradient clipping function takes a list of `(gradients, variables)`
    tuples and returns a list of `(clipped_gradients, variables)` tuples.
    Typical examples include
    :tf_main:`tf.clip_by_global_norm <clip_by_global_norm>`,
    :tf_main:`tf.clip_by_value <clip_by_value>`,
    :tf_main:`tf.clip_by_norm <clip_by_norm>`,
    :tf_main:`tf.clip_by_average_norm <clip_by_average_norm>`, etc.

    Args:
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically.

    Returns:
        function or `None`: If hparams["type"] is specified, returns
        the respective function. If hparams["type"] is empty,
        returns `None`.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(
            hparams, default_optimization_hparams()["gradient_clip"])
    fn_type = hparams["type"]
    if fn_type is None or fn_type == "":
        return None

    fn_modules = ["tensorflow", "texar.tf.custom"]
    clip_fn = utils.get_function(fn_type, fn_modules)
    clip_fn_args = utils.get_args(clip_fn)
    fn_kwargs = hparams["kwargs"]
    if isinstance(fn_kwargs, HParams):
        fn_kwargs = fn_kwargs.todict()

    def grad_clip_fn(grads_and_vars):
        """Gradient clipping function.

        Args:
            grads_and_vars (list): A list of `(gradients, variables)` tuples.

        Returns:
            list: A list of `(clipped_gradients, variables)` tuples.
        """
        grads, vars_ = zip(*grads_and_vars)
        if clip_fn == tf.clip_by_global_norm:
            clipped_grads, _ = clip_fn(t_list=grads, **fn_kwargs)
        elif 't_list' in clip_fn_args:
            clipped_grads = clip_fn(t_list=grads, **fn_kwargs)
        elif 't' in clip_fn_args:     # e.g., tf.clip_by_value
            clipped_grads = [clip_fn(t=grad, **fn_kwargs) for grad in grads]
        else:
            raise ValueError(
                "Unsupported gradient clip function: %s" % fn_type)

        return list(zip(clipped_grads, vars_))

    return grad_clip_fn
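
A minimal usage sketch for `get_gradient_clip_fn` (not part of the original example; the import path and the TF1 graph setup are assumptions):

import tensorflow as tf
from texar.tf.core import get_gradient_clip_fn  # path assumed

# "clip_by_global_norm" is resolved against the "tensorflow" module.
grad_clip_fn = get_gradient_clip_fn({
    "type": "clip_by_global_norm",
    "kwargs": {"clip_norm": 5.0},
})

w = tf.get_variable("w", shape=[10], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(w))
grads_and_vars = tf.train.GradientDescentOptimizer(0.1).compute_gradients(loss, [w])

# Returns a list of (clipped_gradient, variable) tuples.
clipped = grad_clip_fn(grads_and_vars)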
Example #7
def get_optimizer_fn(hparams=None):
    """Returns a function `optimizer_fn` of making optimizer instance, along
    with the optimizer class.

    .. role:: python(code)
       :language: python

    The function has the signature
    :python:`optimizer_fn(learning_rate=None) -> optimizer class instance`

    See the :attr:`"optimizer"` field of
    :meth:`~texar.tf.core.default_optimization_hparams` for all
    hyperparameters and default values.

    The optimizer class must be a subclass of
    :tf_main:`tf.train.Optimizer <train/Optimizer>`.

    Args:
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically.

    Returns:
        - If hparams["type"] is a string or optimizer class, returns\
        `(optimizer_fn, optimizer class)`,

        - If hparams["type"] is an optimizer instance, returns \
        `(the optimizer instance, optimizer class)`
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(
            hparams, default_optimization_hparams()["optimizer"])

    opt = hparams["type"]
    if isinstance(opt, tf.train.Optimizer):
        return opt, type(opt)
    opt_modules = ['tensorflow.train',
                   'tensorflow.contrib.opt',
                   'texar.tf.core.optimization',
                   'texar.tf.custom']
    try:
        opt_class = utils.check_or_get_class(opt, opt_modules,
                                             tf.train.Optimizer)
    except TypeError:
        raise ValueError(
            "Unrecognized optimizer. Must be string name of the "
            "optimizer class, or the class which is a subclass of "
            "tf.train.Optimizer, or an instance of the subclass of "
            "Optimizer.")

    def _get_opt(learning_rate=None):
        opt_kwargs = hparams["kwargs"].todict()
        fn_args = set(utils.get_args(opt_class.__init__))
        if 'learning_rate' in fn_args and learning_rate is not None:
            opt_kwargs["learning_rate"] = learning_rate
        return opt_class(**opt_kwargs)

    return _get_opt, opt_class
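
A minimal sketch of using the returned pair (assumes `get_optimizer_fn` is exposed from `texar.tf.core`, as the docstring's cross-references suggest):

import tensorflow as tf
from texar.tf.core import get_optimizer_fn  # path assumed

opt_fn, opt_class = get_optimizer_fn({
    "type": "AdamOptimizer",     # looked up in tensorflow.train, among others
    "kwargs": {"beta1": 0.9},
})
assert opt_class is tf.train.AdamOptimizer

# The returned function accepts an optional learning rate.
optimizer = opt_fn(learning_rate=1e-3)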
Example #8
    def load_pretrained_config(self,
                               pretrained_model_name=None,
                               cache_dir=None,
                               hparams=None):
        r"""Load paths and configurations of the pre-trained model.

        Args:
            pretrained_model_name (optional): A str with the name
                of a pre-trained model to load. If `None`, will use the model
                name in :attr:`hparams`.
            cache_dir (optional): The path to a folder in which the
                pre-trained models will be cached. If `None` (default),
                a default directory will be used.
            hparams (dict or HParams, optional): Hyperparameters. Missing
                hyperparameter will be set to default values. See
                :meth:`default_hparams` for the hyperparameter structure
                and default values.
        """
        if not hasattr(self, "_hparams"):
            self._hparams = HParams(hparams, self.default_hparams())
        else:
            # Probably already parsed by subclasses. We rely on subclass
            # implementations to get this right.
            # As a sanity check, we require `hparams` to be `None` in this case.
            if hparams is not None:
                raise ValueError(
                    "`self._hparams` is already assigned, but `hparams` "
                    "argument is not None.")

        self.pretrained_model_dir = None
        self.pretrained_model_name = pretrained_model_name

        if self.pretrained_model_name is None:
            self.pretrained_model_name = self._hparams.pretrained_model_name
        if self.pretrained_model_name is not None:
            self.pretrained_model_dir = self.download_checkpoint(
                self.pretrained_model_name, cache_dir)
            pretrained_model_hparams = self._transform_config(
                self.pretrained_model_name, self.pretrained_model_dir)
            self._hparams = HParams(
                pretrained_model_hparams, self._hparams.todict())
Example #9
    def __init__(self,
                 pretrained_model_name=None,
                 cache_dir=None,
                 hparams=None):
        PretrainedBase.__init__(self, pretrained_model_name, cache_dir,
                                hparams)
        if self.pretrained_model_dir:
            self._hparams = HParams(self.pretrained_model_hparams,
                                    self._hparams.todict())

        with tf.variable_scope(self.variable_scope):
            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer(
                    layers.get_initializer(self._hparams.initializer))

            # Word embedding
            self.word_embedder = WordEmbedder(
                vocab_size=self._hparams.vocab_size,
                hparams=self._hparams.embed)

            # Segment embedding for each type of tokens
            self.segment_embedder = WordEmbedder(
                vocab_size=self._hparams.type_vocab_size,
                hparams=self._hparams.segment_embed)

            # Position embedding
            self.position_embedder = PositionEmbedder(
                position_size=self._hparams.position_size,
                hparams=self._hparams.position_embed)

            # The BERT encoder (a TransformerEncoder)
            self.encoder = TransformerEncoder(hparams=self._hparams.encoder)

            with tf.variable_scope("pooler"):
                kwargs_i = {
                    "units": self._hparams.hidden_size,
                    "activation": tf.tanh
                }
                layer_hparams = {"type": "Dense", "kwargs": kwargs_i}
                self.pooler = layers.get_layer(hparams=layer_hparams)
Example #10
    def test_typecheck(self):
        """Tests type-check functionality.
        """
        def _foo():
            pass

        def _bar():
            pass

        default_hparams = {"fn": _foo, "fn_2": _foo}
        hparams = {"fn": _foo, "fn_2": _bar}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.fn, default_hparams["fn"])
Example #11
    def __init__(self, data_hparams, hparams=None):
        ModelBase.__init__(self, hparams)

        self._data_hparams = HParams(data_hparams,
                                     PairedTextData.default_hparams())

        self._src_vocab = None
        self._tgt_vocab = None
        self._src_embedder = None
        self._tgt_embedder = None
        self._connector = None
        self._encoder = None
        self._decoder = None
Example #12
    def __init__(self, hparams):
        TextDataBase.__init__(self, hparams)
        # Defaultizes hparams of each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for ds_hpms in datasets_hparams:
            data_type = ds_hpms.get("data_type", None)
            defaultized_ds_hpms = HParams(ds_hpms,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        with tf.name_scope(self.name, self.default_hparams()["name"]):
            self._make_data()
Example #13
def _get_static_lr(learning_rate=None, optimizer_class=None, hparams=None):
    """Return the base static learning_rate.
        A helper function for creating the optimization function.
    """
    hparams = HParams(hparams, default_optimization_hparams())
    opt_hparams = hparams['optimizer']
    if learning_rate is None:
        learning_rate = opt_hparams["kwargs"].get("learning_rate", None)
    if learning_rate is None:
        # Try to get learning_rate from the default value of the
        # optimizer's argument
        opt_argspec = utils.get_default_arg_values(optimizer_class.__init__)
        learning_rate = opt_argspec.get("learning_rate", None)
    return learning_rate
Example #14
def get_embedding(hparams=None,
                  init_value=None,
                  num_embeds=None,
                  variable_scope='Embedding'):
    r"""Creates embedding variable if not exists.

    Args:
        hparams (dict or HParams, optional): Embedding hyperparameters. Missing
            hyperparameters are set to default values. See
            :func:`~texar.tf.modules.default_embedding_hparams`
            for all hyperparameters and default values.

            If :attr:`init_value` is given, :attr:`hparams["initializer"]`,
            and :attr:`hparams["dim"]` are ignored.
        init_value (Tensor or numpy array, optional): Initial values of the
            embedding variable. If not given, embedding is initialized as
            specified in :attr:`hparams["initializer"]`.
        num_embeds (int, optional): The number of embedding items
            (e.g., vocabulary size). Required if :attr:`init_value` is
            not provided.
        variable_scope (str or VariableScope, optional): Variable scope of
            the embedding variable.

    Returns:
        Variable or Tensor: A 2D `Variable` or `Tensor` of the same shape as
        :attr:`init_value`, or of the shape ``[num_embeds, hparams["dim"]]``.
    """
    with tf.variable_scope(variable_scope):
        if hparams is None or isinstance(hparams, dict):
            hparams = HParams(hparams, default_embedding_hparams())
        regularizer = layers.get_regularizer(hparams["regularizer"])
        if init_value is None:
            initializer = layers.get_initializer(hparams["initializer"])
            dim = hparams["dim"]
            if not isinstance(hparams["dim"], (list, tuple)):
                dim = [dim]
            embedding = tf.get_variable(name='w',
                                        shape=[num_embeds] + dim,
                                        initializer=initializer,
                                        regularizer=regularizer,
                                        trainable=hparams["trainable"])
        else:
            init_value = tf.cast(init_value, tf.float32)
            embedding = tf.get_variable(name='w',
                                        initializer=init_value,
                                        regularizer=regularizer,
                                        trainable=hparams["trainable"])

        return embedding
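
A short usage sketch for `get_embedding` (hedged: it assumes the function above is in scope, e.g. imported from its Texar-TF module, and that a TF1 graph is active):

import numpy as np
import tensorflow as tf

# Case 1: a [num_embeds, dim] variable initialized per hparams["initializer"].
emb = get_embedding(hparams={"dim": 100}, num_embeds=5000)

# Case 2: from explicit initial values; hparams["initializer"] and
# hparams["dim"] are then ignored.
init = np.random.randn(5000, 100).astype(np.float32)
emb2 = get_embedding(init_value=init, variable_scope="Embedding2")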
Example #15
    def __init__(self, hparams=None):
        if not hasattr(self, '_hparams'):
            self._hparams = HParams(hparams, self.default_hparams())
        else:
            # Probably already parsed by subclasses. We rely on subclass
            # implementations to get this right.
            # As a sanity check, we require `hparams` to be `None` in this case.
            if hparams is not None:
                raise ValueError(
                    "`self._hparams` already exists. Argument `hparams` "
                    "must be set to `None` in this case.")
        self._template = tf.make_template(self._hparams.name,
                                          self._build,
                                          create_scope_now_=True)
        self._unique_name = self.variable_scope.name.split("/")[-1]
        self._trainable_variables = []
        self._built = False
Example #16
def get_optimizer(learning_rate=None, global_step=None, hparams=None):

    """Creates a optimizer instance.

    Args:
        learning_rate (float or Tensor, optional): If `None`, learning rate
            specified in :attr:`hparams`, or the default learning rate
            of the optimizer (if exists) is used.
        global_step (optional): A scalar int Tensor. Step counter to update on
            each step unless :attr:`increment_global_step` is `False`.
            Learning rate decay uses :attr:`global_step`.
            If `None`, it will be fetched from the default graph (see
            :tf_main:`tf.train.get_global_step <train/get_global_step>` for
            more details). If it has not been created, no step will be
            incremented with each weight update.
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.tf.core.default_optimization_hparams` for
            all hyperparameters and default values.

    Returns:
        optimizer: the tf.train.Optimizer instance specified in hparams.
    """
    hparams = HParams(hparams, default_optimization_hparams())

    opt_hparams = hparams["optimizer"]
    optimizer_fn, optimizer_class = get_optimizer_fn(opt_hparams)

    static_lr = _get_static_lr(learning_rate, optimizer_class, hparams)

    lr_decay_fn = get_learning_rate_decay_fn(hparams["learning_rate_decay"])
    if lr_decay_fn is not None:
        learning_rate = lr_decay_fn(learning_rate=static_lr,
                                    global_step=global_step)
    else:
        learning_rate = static_lr

    tf.summary.scalar("learning_rate", learning_rate)

    optimizer = optimizer_fn(learning_rate=learning_rate)

    return optimizer
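
A minimal usage sketch for `get_optimizer` (assumes the function above and its helpers are in scope; the hparams structure follows `default_optimization_hparams`):

import tensorflow as tf

global_step = tf.Variable(0, trainable=False, name="global_step")
optimizer = get_optimizer(
    learning_rate=1e-3,
    global_step=global_step,
    hparams={"optimizer": {"type": "AdamOptimizer"}})

# `optimizer` is a regular tf.train.Optimizer and can be used as usual, e.g.:
# train_op = optimizer.minimize(loss, global_step=global_step)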
Example #17
def get_regularizer(hparams=None):
    """Returns a variable regularizer instance.

    See :func:`~texar.tf.core.default_regularizer_hparams` for all
    hyperparameters and default values.

    The "type" field can be a subclass
    of :tf_main:`Regularizer <keras/regularizers/Regularizer>`, its string name
    or module path, or a class instance.

    Args:
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters are set to default values.

    Returns:
        A :tf_main:`Regularizer <keras/regularizers/Regularizer>` instance.
        `None` if :attr:`hparams` is `None` or taking the default
        hyperparameter value.

    Raises:
        ValueError: The resulting regularizer is not an instance of
            :tf_main:`Regularizer <keras/regularizers/Regularizer>`.
    """
    if hparams is None:
        return None

    if isinstance(hparams, dict):
        hparams = HParams(hparams, default_regularizer_hparams())

    rgl = utils.check_or_get_instance(
        hparams.type, hparams.kwargs.todict(),
        ["tensorflow.keras.regularizers", "texar.tf.custom"])

    if not isinstance(rgl, tf.keras.regularizers.Regularizer):
        raise ValueError("The regularizer must be an instance of "
                         "tf.keras.regularizers.Regularizer.")

    if isinstance(rgl, tf.keras.regularizers.L1L2) and \
            rgl.l1 == 0. and rgl.l2 == 0.:
        return None

    return rgl
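
A short usage sketch for `get_regularizer` (assumes the function above is in scope and that `default_regularizer_hparams` uses the "type"/"kwargs" structure shown elsewhere on this page):

import tensorflow as tf

# An L2 regularizer built from hparams; "L1L2" is resolved against
# tensorflow.keras.regularizers.
regularizer = get_regularizer({
    "type": "L1L2",
    "kwargs": {"l1": 0.0, "l2": 1e-4},
})
w = tf.get_variable("w_reg", shape=[10, 10], regularizer=regularizer)

# `None` hparams (or an all-zero L1L2) yield no regularizer.
assert get_regularizer(None) is None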
Example #18
def get_train_op(loss, variables=None,
                 optimizer=None, learning_rate=None,
                 global_step=None, increment_global_step=True, hparams=None):
    """Creates a training op.

    This is a wrapper of :tf_main:`tf.contrib.layers.optimize_loss
    <contrib/layers/optimize_loss>`.

    Args:
        loss: A scalar Tensor representing the loss to minimize.
        variables (optional): A list of Variables to optimize. If
            `None`, all trainable variables are used.
        optimizer (optional): A tf.train.Optimizer instance. If `None`,
            use the setting in `hparams` to create the optimizer.
        learning_rate (float or Tensor, optional): If `None`, learning rate
            specified in :attr:`hparams`, or the default learning rate
            of the optimizer will be used (if exists).
        global_step (optional): A scalar int Tensor. Step counter to update on
            each step unless :attr:`increment_global_step` is `False`.
            Learning rate decay uses :attr:`global_step`.
            If `None`, it will be fetched from the default graph (see
            :tf_main:`tf.train.get_global_step <train/get_global_step>` for
            more details). If it has not been created, no step will be
            incremented with each weight update.
        increment_global_step (bool): Whether to increment
            :attr:`global_step`. This is useful if the :attr:`global_step` is
            used in multiple training ops per training step (e.g. to optimize
            different parts of the model) to avoid incrementing
            :attr:`global_step` more times than necessary.
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.tf.core.default_optimization_hparams` for
            all hyperparameters and default values.

    Returns:
        train_op: the operator used for variables optimization.
    """
    hparams = HParams(hparams, default_optimization_hparams())
    grad_clip_fn = get_gradient_clip_fn(hparams["gradient_clip"])

    if not isinstance(optimizer, tf.train.Optimizer):
        opt_hparams = hparams["optimizer"]
        optimizer_fn, optimizer_class = get_optimizer_fn(opt_hparams)
        learning_rate = _get_static_lr(learning_rate, optimizer_class, hparams)
        lr_decay_fn = get_learning_rate_decay_fn(
            hparams["learning_rate_decay"])
        train_op = tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=global_step,
            learning_rate=learning_rate,
            optimizer=optimizer_fn,
            gradient_noise_scale=hparams["gradient_noise_scale"],
            clip_gradients=grad_clip_fn,
            learning_rate_decay_fn=lr_decay_fn,
            variables=variables,
            name=hparams["name"],
            increment_global_step=increment_global_step)

    else:
        train_op = tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=global_step,
            learning_rate=None,
            optimizer=optimizer,
            gradient_noise_scale=hparams["gradient_noise_scale"],
            clip_gradients=grad_clip_fn,
            variables=variables,
            name=hparams["name"],
            increment_global_step=increment_global_step)

    return train_op
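
A minimal end-to-end sketch for `get_train_op` (assumes the functions above are in scope; the placeholder/loss setup is illustrative only):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
y = tf.placeholder(tf.float32, [None, 1])
loss = tf.losses.mean_squared_error(y, tf.layers.dense(x, 1))

# Optimizer and gradient clipping are both configured through hparams.
train_op = get_train_op(
    loss,
    hparams={
        "optimizer": {"type": "AdamOptimizer",
                      "kwargs": {"learning_rate": 1e-3}},
        "gradient_clip": {"type": "clip_by_global_norm",
                          "kwargs": {"clip_norm": 5.0}},
    })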
Example #19
    def test_hparams(self):
        """Tests the HParams class.
        """
        default_hparams = {
            "str": "str",
            "list": ['item1', 'item2'],
            "dict": {
                "key1": "value1",
                "key2": "value2"
            },
            "nested_dict": {
                "dict_l2": {
                    "key1_l2": "value1_l2"
                }
            },
            "type": "type",
            "kwargs": {
                "arg1": "argv1"
            },
        }

        # Test HParams.items() function
        hparams_ = HParams(None, default_hparams)
        names = []
        for name, _ in hparams_.items():
            names.append(name)
        self.assertEqual(set(names), set(default_hparams.keys()))

        hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}

        hparams_ = HParams(hparams, default_hparams)

        # Test HParams construction
        self.assertEqual(hparams_.str, default_hparams["str"])
        self.assertEqual(hparams_.list, default_hparams["list"])
        self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
        self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
        self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                         default_hparams["nested_dict"]["dict_l2"]["key1_l2"])

        self.assertEqual(len(hparams_), len(default_hparams))

        new_hparams = copy.deepcopy(default_hparams)
        new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
        new_hparams["kwargs"].update(hparams["kwargs"])
        self.assertEqual(hparams_.todict(), new_hparams)

        self.assertTrue("dict" in hparams_)

        self.assertIsNone(hparams_.get('not_existed_name', None))
        self.assertEqual(hparams_.get('str'), default_hparams['str'])

        # Test HParams update related operations
        hparams_.str = "new_str"
        hparams_.dict = {"key3": "value3"}
        self.assertEqual(hparams_.str, "new_str")
        self.assertEqual(hparams_.dict.key3, "value3")

        hparams_.add_hparam("added_str", "added_str")
        hparams_.add_hparam("added_dict", {"key4": "value4"})
        hparams_.kwargs.add_hparam("added_arg", "added_argv")
        self.assertEqual(hparams_.added_str, "added_str")
        self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
        self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

        # Test HParams I/O
        hparams_file = tempfile.NamedTemporaryFile()
        pickle.dump(hparams_, hparams_file)
        with open(hparams_file.name, 'rb') as hparams_file:
            hparams_loaded = pickle.load(hparams_file)
        self.assertEqual(hparams_loaded.todict(), hparams_.todict())
Example #20
    def __init__(self, hparams=None):
        self._hparams = HParams(hparams,
                                self.default_hparams(),
                                allow_new_hparam=True)
Example #21
def get_learning_rate_decay_fn(hparams=None):
    """Creates learning rate decay function based on the hyperparameters.

    See the :attr:`learning_rate_decay` field in
    :meth:`~texar.tf.core.default_optimization_hparams` for all
    hyperparameters and default values.

    Args:
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically.

    Returns:
        function or None: If hparams["type"] is specified, returns a
        function that takes `(learning_rate, global_step)` and
        returns a decayed learning rate. If
        hparams["type"] is empty, returns `None`.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(
            hparams, default_optimization_hparams()["learning_rate_decay"])

    fn_type = hparams["type"]
    if fn_type is None or fn_type == "":
        return None

    fn_modules = ["tensorflow.train", "texar.tf.custom"]
    decay_fn = utils.get_function(fn_type, fn_modules)
    fn_kwargs = hparams["kwargs"]
    if isinstance(fn_kwargs, HParams):
        fn_kwargs = fn_kwargs.todict()

    start_step = tf.cast(hparams["start_decay_step"], tf.int32)
    end_step = tf.cast(hparams["end_decay_step"], tf.int32)

    def lr_decay_fn(learning_rate, global_step):
        """Learning rate decay function.

        Args:
            learning_rate (float or Tensor): The original learning rate.
            global_step (int or scalar int Tensor): optimization step counter.

        Returns:
            scalar float Tensor: decayed learning rate.
        """
        offset_global_step = tf.maximum(
            tf.minimum(tf.cast(global_step, tf.int32), end_step) - start_step,
            0)
        if decay_fn == tf.train.piecewise_constant:
            decayed_lr = decay_fn(x=offset_global_step, **fn_kwargs)
        else:
            fn_kwargs_ = {
                "learning_rate": learning_rate,
                "global_step": offset_global_step}
            fn_kwargs_.update(fn_kwargs)
            decayed_lr = utils.call_function_with_redundant_kwargs(
                decay_fn, fn_kwargs_)

            decayed_lr = tf.maximum(decayed_lr, hparams["min_learning_rate"])

        return decayed_lr

    return lr_decay_fn
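
A short usage sketch for `get_learning_rate_decay_fn` (assumes the function above is in scope; the hparams keys mirror those referenced in the code):

import tensorflow as tf

lr_decay_fn = get_learning_rate_decay_fn({
    "type": "exponential_decay",   # resolved against tensorflow.train
    "kwargs": {"decay_steps": 1000, "decay_rate": 0.96, "staircase": True},
    "start_decay_step": 0,
    "end_decay_step": 10000,
    "min_learning_rate": 1e-5,
})

global_step = tf.train.get_or_create_global_step()
decayed_lr = lr_decay_fn(learning_rate=0.001, global_step=global_step)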
Example #22
def get_layer(hparams):
    r"""Makes a layer instance.

    The layer must be an instance of :tf_main:`tf.layers.Layer <layers/Layer>`.

    Args:
        hparams (dict or HParams): Hyperparameters of the layer, with
            structure:

            .. code-block:: python

                {
                    "type": "LayerClass",
                    "kwargs": {
                        # Keyword arguments of the layer class
                        # ...
                    }
                }

            Here:

            `"type"`: str or layer class or layer instance
                The layer type. This can be

                - The string name or full module path of a layer class. If
                  the class name is provided, the class must be in module
                  :tf_main:`tf.layers <layers>`, :mod:`texar.tf.core`,
                  or :mod:`texar.tf.custom`.
                - A layer class.
                - An instance of a layer class.

                For example

                .. code-block:: python

                    "type": "Conv1D" # class name
                    "type": "texar.tf.core.MaxReducePooling1D" # module path
                    "type": "my_module.MyLayer" # module path
                    "type": tf.layers.Conv2D # class
                    "type": Conv1D(filters=10, kernel_size=2) # cell instance
                    "type": MyLayer(...) # cell instance

            `"kwargs"`: dict
                A dictionary of keyword arguments for constructor of the
                layer class. Ignored if :attr:`"type"` is a layer instance.

                - Arguments named "activation" can be a callable,
                  or a `str` of the name or module path to the activation
                  function.
                - Arguments named "\*_regularizer" and "\*_initializer"
                  can be a class instance, or a `dict` of hyperparameters of
                  respective regularizers and initializers. See
                  :func:`~texar.tf.core.get_regularizer` and
                  :func:`~texar.tf.core.get_initializer` for details.
                - Arguments named "\*_constraint" can be a callable, or a
                  `str` of the name or full path to the constraint function.

    Returns:
        A layer instance. If ``hparams["type"]`` is a layer instance, returns it
        directly.

    Raises:
        ValueError: If :attr:`hparams` is `None`.
        ValueError: If the resulting layer is not an instance of
            :tf_main:`tf.layers.Layer <layers/Layer>`.
    """
    if hparams is None:
        raise ValueError("`hparams` must not be `None`.")

    layer_type = hparams["type"]
    if not is_str(layer_type) and not isinstance(layer_type, type):
        layer = layer_type
    else:
        layer_modules = [
            "tensorflow.layers", "texar.tf.core", "texar.tf.custom"
        ]
        layer_class = utils.check_or_get_class(layer_type, layer_modules)
        if isinstance(hparams, dict):
            default_kwargs = _layer_class_to_default_kwargs_map.get(
                layer_class, {})
            default_hparams = {"type": layer_type, "kwargs": default_kwargs}
            hparams = HParams(hparams, default_hparams)

        kwargs = {}
        for k, v in hparams.kwargs.items():
            if k.endswith('_regularizer'):
                kwargs[k] = get_regularizer(v)
            elif k.endswith('_initializer'):
                kwargs[k] = get_initializer(v)
            elif k.endswith('activation'):
                kwargs[k] = get_activation_fn(v)
            elif k.endswith('_constraint'):
                kwargs[k] = get_constraint_fn(v)
            else:
                kwargs[k] = v
        layer = utils.get_instance(layer_type, kwargs, layer_modules)

    if not isinstance(layer, tf.layers.Layer):
        raise ValueError("layer must be an instance of `tf.layers.Layer`.")

    return layer
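
A brief usage sketch for `get_layer`, complementing the "Dense" pooler construction in Example #9 (assumes the function above is in scope):

import tensorflow as tf

# By class name: kwargs are forwarded to the tf.layers.Dense constructor.
dense = get_layer({
    "type": "Dense",
    "kwargs": {"units": 256, "activation": tf.tanh},
})

# By instance: the layer is returned as-is and "kwargs" is not needed.
conv = get_layer({"type": tf.layers.Conv1D(filters=10, kernel_size=2)})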
Example #23
def get_rnn_cell(hparams=None, mode=None):
    """Creates an RNN cell.

    See :func:`~texar.tf.core.default_rnn_cell_hparams` for all
    hyperparameters and default values.

    Args:
        hparams (dict or HParams, optional): Cell hyperparameters. Missing
            hyperparameters are set to default values.
        mode (optional): A Tensor taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
            `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout will be
            controlled by :func:`texar.tf.global_mode`.

    Returns:
        A cell instance.

    Raises:
        ValueError: If hparams["num_layers"]>1 and hparams["type"] is a class
            instance.
        ValueError: The cell is not an
            :tf_main:`RNNCell <contrib/rnn/RNNCell>` instance.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_rnn_cell_hparams())

    d_hp = hparams["dropout"]
    if d_hp["variational_recurrent"] and \
            len(d_hp["input_size"]) != hparams["num_layers"]:
        raise ValueError(
            "If variational_recurrent=True, input_size must be a list of "
            "num_layers(%d) integers. Got len(input_size)=%d." %
            (hparams["num_layers"], len(d_hp["input_size"])))

    cells = []
    cell_kwargs = hparams["kwargs"].todict()
    num_layers = hparams["num_layers"]
    for layer_i in range(num_layers):
        # Create the basic cell
        cell_type = hparams["type"]
        if not is_str(cell_type) and not isinstance(cell_type, type):
            if num_layers > 1:
                raise ValueError(
                    "If 'num_layers'>1, then 'type' must be a cell class or "
                    "its name/module path, rather than a cell instance.")
        cell_modules = [
            'tensorflow.nn.rnn_cell', 'tensorflow.contrib.rnn',
            'texar.tf.custom'
        ]
        cell = utils.check_or_get_instance(cell_type, cell_kwargs,
                                           cell_modules, rnn.RNNCell)

        # Optionally add dropout
        if d_hp["input_keep_prob"] < 1.0 or \
                d_hp["output_keep_prob"] < 1.0 or \
                d_hp["state_keep_prob"] < 1.0:
            vr_kwargs = {}
            if d_hp["variational_recurrent"]:
                vr_kwargs = {
                    "variational_recurrent": True,
                    "input_size": d_hp["input_size"][layer_i],
                    "dtype": tf.float32
                }
            input_keep_prob = switch_dropout(d_hp["input_keep_prob"], mode)
            output_keep_prob = switch_dropout(d_hp["output_keep_prob"], mode)
            state_keep_prob = switch_dropout(d_hp["state_keep_prob"], mode)
            cell = rnn.DropoutWrapper(cell=cell,
                                      input_keep_prob=input_keep_prob,
                                      output_keep_prob=output_keep_prob,
                                      state_keep_prob=state_keep_prob,
                                      **vr_kwargs)

        # Optionally add residual and highway connections
        if layer_i > 0:
            if hparams["residual"]:
                cell = rnn.ResidualWrapper(cell)
            if hparams["highway"]:
                cell = rnn.HighwayWrapper(cell)

        cells.append(cell)

    if hparams["num_layers"] > 1:
        cell = rnn.MultiRNNCell(cells)
    else:
        cell = cells[0]

    return cell
Example #24
    def __init__(self, hparams=None):
        self._hparams = HParams(hparams, self.default_hparams())

        name = self._hparams.name
        self._variable_scope = get_unique_named_variable_scope(name)
        self._unique_name = self._variable_scope.name.split("/")[-1]
Example #25
    def __init__(self,
                 pretrained_model_name=None,
                 cache_dir=None,
                 hparams=None):
        PretrainedBase.__init__(self, pretrained_model_name, cache_dir,
                                hparams)

        if self.pretrained_model_dir:
            self._hparams = HParams(self.pretrained_model_hparams,
                                    self._hparams.todict())

        num_layers = self._hparams.num_layers
        use_segments = self._hparams.use_segments
        untie_r = self._hparams.untie_r

        with tf.variable_scope(self.variable_scope):

            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer(
                    layers.get_initializer(self._hparams.initializer))

            if untie_r:
                self.r_w_bias = tf.get_variable('r_w_bias', [
                    num_layers, self._hparams.num_heads, self._hparams.head_dim
                ],
                                                dtype=tf.float32)
                self.r_r_bias = tf.get_variable('r_r_bias', [
                    num_layers, self._hparams.num_heads, self._hparams.head_dim
                ],
                                                dtype=tf.float32)
            else:
                self.r_w_bias = tf.get_variable(
                    'r_w_bias',
                    [self._hparams.num_heads, self._hparams.head_dim],
                    dtype=tf.float32)
                self.r_r_bias = tf.get_variable(
                    'r_r_bias',
                    [self._hparams.num_heads, self._hparams.head_dim],
                    dtype=tf.float32)

            if use_segments:
                self.segment_embed = tf.get_variable('seg_embed', [
                    num_layers, 2, self._hparams.num_heads,
                    self._hparams.head_dim
                ],
                                                     dtype=tf.float32)
                self.r_s_bias = (tf.get_variable(
                    'r_s_bias', [
                        num_layers, self._hparams.num_heads,
                        self._hparams.head_dim
                    ],
                    dtype=tf.float32) if untie_r else tf.get_variable(
                        'r_s_bias',
                        [self._hparams.num_heads, self._hparams.head_dim],
                        dtype=tf.float32))
            else:
                self.segment_embed = None
                self.r_s_bias = None

            # Word embedding
            self.word_embedder = WordEmbedder(
                vocab_size=self._hparams.vocab_size,
                hparams={"dim": self._hparams.hidden_dim})

            # Position embedding
            self.pos_embed = RelativePositionalEncoding(
                hparams={
                    "dim": self._hparams.hidden_dim,
                    "max_seq_len": self._hparams.max_seq_len
                })

            self.attn_layers = []
            self.ff_layers = []
            rel_attn_hparams = dict_fetch(
                self._hparams, RelativeMutiheadAttention.default_hparams())
            rel_attn_hparams["name"] = "rel_attn"

            ff_hparams = dict_fetch(self._hparams,
                                    PositionWiseFF.default_hparams())
            ff_hparams["name"] = "ff"

            for i in range(num_layers):
                with tf.variable_scope("layer_{}".format(i)):
                    if self._hparams.untie_r:
                        if use_segments:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias[i],
                                    self.r_w_bias[i],
                                    self.r_s_bias[i],
                                    self.segment_embed[i],
                                    hparams=rel_attn_hparams))
                        else:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias[i],
                                    self.r_w_bias[i],
                                    hparams=rel_attn_hparams))
                    else:
                        if use_segments:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias,
                                    self.r_w_bias,
                                    self.r_s_bias,
                                    self.segment_embed[i],
                                    hparams=rel_attn_hparams))
                        else:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias,
                                    self.r_w_bias,
                                    hparams=rel_attn_hparams))
                    self.ff_layers.append(PositionWiseFF(hparams=ff_hparams))

            dropout_hparams = {
                "type": "Dropout",
                "kwargs": {
                    "rate": self._hparams.dropout
                }
            }
            self.dropout = layers.get_layer(hparams=dropout_hparams)

            self.mask_embed = tf.get_variable('mask_emb',
                                              [1, 1, self.hparams.hidden_dim],
                                              dtype=tf.float32)
Example #26
class BertEncoder(PretrainedBase, EncoderBase):
    """Raw BERT Transformer for encoding sequences.

    This module basically stacks
    :class:`~texar.tf.modules.embedders.WordEmbedder`,
    :class:`~texar.tf.modules.embedders.PositionEmbedder`,
    :class:`~texar.tf.modules.encoders.TransformerEncoder` and a dense pooler.

    This module supports the BERT architecture first proposed
    in `(Devlin et al.)`.

    Args:
        pretrained_model_name (optional): a str with the name
            of a pre-trained model to load, selected from the following list:
            `bert-base-uncased`, `bert-large-uncased`, `bert-base-cased`,
            `bert-large-cased`, `bert-base-multilingual-uncased`,
            `bert-base-multilingual-cased`, `bert-base-chinese`.
            If `None`, will use the model name in :attr:`hparams`.
        cache_dir (optional): the path to a folder in which the
            pre-trained models will be cached. If `None` (default),
            a default directory will be used.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameter will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure
            and default values.

    .. document private functions
    .. automethod:: _build
    """

    model_name = "BERT"

    def __init__(self,
                 pretrained_model_name=None,
                 cache_dir=None,
                 hparams=None):
        PretrainedBase.__init__(self, pretrained_model_name, cache_dir,
                                hparams)
        if self.pretrained_model_dir:
            self._hparams = HParams(self.pretrained_model_hparams,
                                    self._hparams.todict())

        with tf.variable_scope(self.variable_scope):
            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer(
                    layers.get_initializer(self._hparams.initializer))

            # Word embedding
            self.word_embedder = WordEmbedder(
                vocab_size=self._hparams.vocab_size,
                hparams=self._hparams.embed)

            # Segment embedding for each type of tokens
            self.segment_embedder = WordEmbedder(
                vocab_size=self._hparams.type_vocab_size,
                hparams=self._hparams.segment_embed)

            # Position embedding
            self.position_embedder = PositionEmbedder(
                position_size=self._hparams.position_size,
                hparams=self._hparams.position_embed)

            # The BERT encoder (a TransformerEncoder)
            self.encoder = TransformerEncoder(hparams=self._hparams.encoder)

            with tf.variable_scope("pooler"):
                kwargs_i = {
                    "units": self._hparams.hidden_size,
                    "activation": tf.tanh
                }
                layer_hparams = {"type": "Dense", "kwargs": kwargs_i}
                self.pooler = layers.get_layer(hparams=layer_hparams)

    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        * The encoder arch is determined by the constructor argument \
        :attr:`pretrained_model_name` if it's specified. In this case, \
        hparams are ignored.
        * Otherwise, the encoder arch is determined by \
        `hparams['pretrained_model_name']` if it's specified. All other \
        configs in hparams are ignored.
        * If the above two are `None`, the encoder arch is defined by \
        the configs in hparams and weights are randomly initialized.

        .. code-block:: python

            {
                'pretrained_model_name': 'bert-base-uncased',
                'embed': {
                    'dim': 768,
                    'name': 'word_embeddings'
                },
                'vocab_size': 30522,
                'segment_embed': {
                    'dim': 768,
                    'name': 'token_type_embeddings'
                },
                'type_vocab_size': 2,
                'position_embed': {
                    'dim': 768,
                    'name': 'position_embeddings'
                },
                'position_size': 512,

                'encoder': {
                    'dim': 768,
                    'embedding_dropout': 0.1,
                    'multihead_attention': {
                        'dropout_rate': 0.1,
                        'name': 'self',
                        'num_heads': 12,
                        'num_units': 768,
                        'output_dim': 768,
                        'use_bias': True
                    },
                    'name': 'encoder',
                    'num_blocks': 12,
                    'poswise_feedforward': {
                        'layers': [
                            {   'kwargs': {
                                    'activation': 'gelu',
                                    'name': 'intermediate',
                                    'units': 3072,
                                    'use_bias': True
                                },
                                'type': 'Dense'
                            },
                            {   'kwargs': {'activation': None,
                                'name': 'output',
                                'units': 768,
                                'use_bias': True
                                },
                                'type': 'Dense'
                            }
                        ]
                    },
                    'residual_dropout': 0.1,
                    'use_bert_config': True
                },
                'hidden_size': 768,
                'initializer': None,
                'name': 'bert_encoder'
            }



        Here:

        The default parameters are values for the uncased BERT-Base model.


        "pretrained_model_name": str or None
             The name of the pretrained bert model. If None, the model
             will be randomly initialized.

        "embed": dict
            Hyperparameters for word embedding layer.

        "vocab_size": int
            The vocabulary size of `inputs` in `BertModel`.

        "segment_embed": dict
            Hyperparameters for segment embedding layer.

        "type_vocab_size": int
            The vocabulary size of the `segment_ids` passed into `BertModel`.

        "position_embed": dict
            Hyperparameters for position embedding layer.

        "position_size":  int
            The maximum sequence length that this model might ever be used with.

        "encoder": dict
            Hyperparameters for the TransformerEncoder.
            See :func:`~texar.tf.modules.TransformerEncoder.default_hparams`
            for details.

        "hidden_size": int
            Size of the pooler dense layer.

        "initializer": dict, optional
            Hyperparameters of the default initializer that initializes
            variables created in this module.
            See :func:`~texar.tf.core.get_initializer` for details.

        "name": str
            Name of the module.
        """

        return {
            'pretrained_model_name': 'bert-base-uncased',
            'embed': {
                'dim': 768,
                'name': 'word_embeddings'
            },
            'vocab_size': 30522,
            'segment_embed': {
                'dim': 768,
                'name': 'token_type_embeddings'
            },
            'type_vocab_size': 2,
            'position_embed': {
                'dim': 768,
                'name': 'position_embeddings'
            },
            'position_size': 512,
            'encoder': {
                'dim': 768,
                'embedding_dropout': 0.1,
                'multihead_attention': {
                    'dropout_rate': 0.1,
                    'name': 'self',
                    'num_heads': 12,
                    'num_units': 768,
                    'output_dim': 768,
                    'use_bias': True
                },
                'name': 'encoder',
                'num_blocks': 12,
                'poswise_feedforward': {
                    'layers': [{
                        'kwargs': {
                            'activation': 'gelu',
                            'name': 'intermediate',
                            'units': 3072,
                            'use_bias': True
                        },
                        'type': 'Dense'
                    }, {
                        'kwargs': {
                            'activation': None,
                            'name': 'output',
                            'units': 768,
                            'use_bias': True
                        },
                        'type': 'Dense'
                    }]
                },
                'residual_dropout': 0.1,
                'use_bert_config': True
            },
            'hidden_size': 768,
            'initializer': None,
            'name': 'bert_encoder',
            '@no_typecheck': ['pretrained_model_name']
        }

    def _build(self,
               inputs,
               sequence_length=None,
               segment_ids=None,
               mode=None,
               **kwargs):
        """Encodes the inputs.

        Args:
            inputs: A 2D Tensor of shape `[batch_size, max_time]`,
                containing the token ids of tokens in the input sequences.
            segment_ids (optional): A 2D Tensor of shape
                `[batch_size, max_time]`, containing the segment ids
                of tokens in input sequences. If `None` (default), a
                tensor with all elements set to zero is used.
            sequence_length (optional): A 1D Tensor of shape `[batch_size]`.
                Input tokens beyond respective sequence lengths are masked
                out automatically.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.tf.global_mode` is used.
            **kwargs: Keyword arguments.

        Returns:
            A pair :attr:`(outputs, pooled_output)`

                - :attr:`outputs`:  A Tensor of shape \
                `[batch_size, max_time, dim]` containing the \
                 encoded vectors.

                - :attr:`pooled_output`: A Tensor of size \
                `[batch_size, hidden_size]` which is the output of a \
                pooler pretrained on top of the hidden state associated \
                with the first token of the input (`CLS`), see BERT's \
                paper.
        """

        if segment_ids is None:
            segment_ids = tf.zeros_like(inputs)

        word_embeds = self.word_embedder(inputs)

        segment_embeds = self.segment_embedder(segment_ids)

        batch_size = tf.shape(inputs)[0]
        pos_length = tf.ones([batch_size], tf.int32) * tf.shape(inputs)[1]
        pos_embeds = self.position_embedder(sequence_length=pos_length)

        input_embeds = word_embeds + segment_embeds + pos_embeds

        if sequence_length is None:
            sequence_length = tf.ones([batch_size], tf.int32) \
                              * tf.shape(inputs)[1]

        output = self.encoder(input_embeds, sequence_length, mode)

        with tf.variable_scope("pooler"):
            # taking the hidden state corresponding to the first token.
            first_token_tensor = tf.squeeze(output[:, 0:1, :], axis=1)
            pooled_output = self.pooler(first_token_tensor)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

            if self.pretrained_model_dir:
                bert_utils.init_bert_checkpoint(self.pretrained_model_dir,
                                                self.variable_scope.name)

        return output, pooled_output
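
# A minimal usage sketch (illustrative, not part of the snippet above; it
# assumes the `texar.tf` package layout, that `BertEncoder` is exposed under
# `texar.tf.modules`, and the "bert-base-uncased" checkpoint name). Calling
# the module invokes the `_build` method shown above.
import tensorflow as tf
import texar.tf as tx

bert = tx.modules.BertEncoder(pretrained_model_name="bert-base-uncased")
input_ids = tf.placeholder(tf.int32, shape=[None, None])
outputs, pooled_output = bert(input_ids)
# outputs:       [batch_size, max_time, hidden_size]
# pooled_output: [batch_size, hidden_size]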
Beispiel #27
0
class PretrainedMixin(ModuleBase):
    r"""A mixin class for all pre-trained classes to inherit.
    """
    __metaclass__ = ABCMeta

    _MODEL_NAME = None
    _MODEL2URL = None

    pretrained_model_dir = None

    @classmethod
    def available_checkpoints(cls):
        return list(cls._MODEL2URL.keys())

    def _name_to_variable(self, name):
        r"""Find the corresponding variable given the specified name.
        """
        pointer = self
        for m_name in name.split("."):
            if m_name.isdigit():
                num = int(m_name)
                pointer = pointer[num]  # type: ignore
            else:
                pointer = getattr(pointer, m_name)
        return pointer  # type: ignore

    def load_pretrained_config(self,
                               pretrained_model_name=None,
                               cache_dir=None,
                               hparams=None):
        r"""Load paths and configurations of the pre-trained model.

        Args:
            pretrained_model_name (optional): A str with the name
                of a pre-trained model to load. If `None`, will use the model
                name in :attr:`hparams`.
            cache_dir (optional): The path to a folder in which the
                pre-trained models will be cached. If `None` (default),
                a default directory will be used.
            hparams (dict or HParams, optional): Hyperparameters. Missing
                hyperparameters will be set to default values. See
                :meth:`default_hparams` for the hyperparameter structure
                and default values.
        """
        if not hasattr(self, "_hparams"):
            self._hparams = HParams(hparams, self.default_hparams())
        else:
            # Probably already parsed by subclasses. We rely on subclass
            # implementations to get this right.
            # As a sanity check, we require `hparams` to be `None` in this case.
            if hparams is not None:
                raise ValueError(
                    "`self._hparams` is already assigned, but `hparams` "
                    "argument is not None.")

        self.pretrained_model_dir = None
        self.pretrained_model_name = pretrained_model_name

        if self.pretrained_model_name is None:
            self.pretrained_model_name = self._hparams.pretrained_model_name
        if self.pretrained_model_name is not None:
            self.pretrained_model_dir = self.download_checkpoint(
                self.pretrained_model_name, cache_dir)
            pretrained_model_hparams = self._transform_config(
                self.pretrained_model_name, self.pretrained_model_dir)
            self._hparams = HParams(
                pretrained_model_hparams, self._hparams.todict())
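
    # Resolution order, sketched with a hypothetical model name: the explicit
    # argument takes precedence over `hparams["pretrained_model_name"]`; if
    # both are None, nothing is downloaded and the weights stay randomly
    # initialized.
    #
    #     model.load_pretrained_config(pretrained_model_name="toy-base")
    #     # or, on a fresh instance whose hparams are not yet parsed:
    #     model.load_pretrained_config(
    #         hparams={"pretrained_model_name": "toy-base"})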

    def init_pretrained_weights(self, scope_name, **kwargs):
        if self.pretrained_model_dir:
            self._init_from_checkpoint(
                self.pretrained_model_name,
                self.pretrained_model_dir, scope_name, **kwargs)
        else:
            self.reset_parameters()

    def reset_parameters(self):
        r"""Initialize parameters of the pre-trained model. This method is only
        called if pre-trained checkpoints are not loaded.
        """
        pass

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "pretrained_model_name": None,
                "name": "pretrained_base"
            }
        """
        return {
            'pretrained_model_name': None,
            'name': "pretrained_base",
            '@no_typecheck': ['pretrained_model_name']
        }

    @classmethod
    def download_checkpoint(cls, pretrained_model_name, cache_dir=None):
        r"""Download the specified pre-trained checkpoint, and return the
        directory in which the checkpoint is cached.

        Args:
            pretrained_model_name (str): Name of the model checkpoint.
            cache_dir (str, optional): Path to the cache directory. If `None`,
                uses the default directory (user's home directory).

        Returns:
            Path to the cache directory.
        """
        if pretrained_model_name in cls._MODEL2URL:
            download_path = cls._MODEL2URL[pretrained_model_name]
        else:
            raise ValueError(
                "Pre-trained model not found: {}".format(pretrained_model_name))

        if cache_dir is None:
            cache_path = default_download_dir(cls._MODEL_NAME)
        else:
            cache_path = Path(cache_dir)
        cache_path = cache_path / pretrained_model_name

        if not cache_path.exists():
            if isinstance(download_path, list):
                for path in download_path:
                    maybe_download(path, str(cache_path))
            else:
                filename = download_path.split('/')[-1]
                maybe_download(download_path, str(cache_path), extract=True)
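                # The archive extracts into a single sub-directory; remove the
                # downloaded archive and flatten the extracted files into
                # `cache_path`.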
                folder = None
                for file in cache_path.iterdir():
                    if file.is_dir():
                        folder = file
                assert folder is not None
                (cache_path / filename).unlink()
                for file in folder.iterdir():
                    file.rename(file.parents[1] / file.name)
                folder.rmdir()
            print("Pre-trained {} checkpoint {} cached to {}".format(
                cls._MODEL_NAME, pretrained_model_name, cache_path))
        else:
            print("Using cached pre-trained {} checkpoint from {}.".format(
                cls._MODEL_NAME, cache_path))

        return str(cache_path)
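
    # Illustrative call on a hypothetical concrete subclass `ToyPretrained`
    # (see the sketch after this class). The first call downloads and extracts
    # the archive; later calls reuse the cached directory:
    #
    #     ckpt_dir = ToyPretrained.download_checkpoint("toy-base")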

    @classmethod
    @abstractmethod
    def _transform_config(cls, pretrained_model_name, cache_dir):
        r"""Load the official configuration file and transform it into
        Texar-style hyperparameters.

        Args:
            pretrained_model_name (str): Name of the pre-trained model.
            cache_dir (str): Path to the cache directory.

        Returns:
            dict: Texar module hyperparameters.
        """
        raise NotImplementedError

    @abstractmethod
    def _init_from_checkpoint(self, pretrained_model_name, cache_dir,
                              scope_name, **kwargs):
        r"""Initialize model parameters from weights stored in the pre-trained
        checkpoint.

        Args:
            pretrained_model_name (str): Name of the pre-trained model.
            cache_dir (str): Path to the cache directory.
            scope_name: Variable scope.
            **kwargs: Additional arguments for specific models.
        """
        raise NotImplementedError
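
# A minimal sketch of a concrete subclass (hypothetical model name and URL;
# the real pre-trained modules in the library are considerably more involved):
class ToyPretrained(PretrainedMixin):
    _MODEL_NAME = "Toy"
    _MODEL2URL = {"toy-base": "https://example.com/toy-base.zip"}

    @classmethod
    def _transform_config(cls, pretrained_model_name, cache_dir):
        # Read the official config file from `cache_dir` and return the
        # corresponding Texar-style hyperparameter dict.
        return {}

    def _init_from_checkpoint(self, pretrained_model_name, cache_dir,
                              scope_name, **kwargs):
        # Assign checkpoint weights to the variables under `scope_name`.
        pass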
Beispiel #28
0
    def __init__(self, hparams=None):
        self._hparams = HParams(hparams, self.default_hparams())
Beispiel #29
0
class XLNetEncoder(PretrainedBase, EncoderBase):
    r"""XLNet Transformer for encoding sequences.

    This module supports the architecture proposed in XLNet
    `(Yang et al.)`.

    Args:
        pretrained_model_name (optional): a str with the name
            of a pre-trained model to load. Currently 'xlnet-large-cased'
            and 'xlnet-base-cased' are supported.
            If `None`, will use the model name in :attr:`hparams`.
        cache_dir (optional): the path to a folder in which the
            pre-trained models will be cached. If `None` (default),
            a default directory will be used.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure
            and default values.

    .. document private functions
    .. automethod:: _build
    """

    model_name = "XLNet"

    def __init__(self,
                 pretrained_model_name=None,
                 cache_dir=None,
                 hparams=None):
        PretrainedBase.__init__(self, pretrained_model_name, cache_dir,
                                hparams)

        if self.pretrained_model_dir:
            self._hparams = HParams(self.pretrained_model_hparams,
                                    self._hparams.todict())

        num_layers = self._hparams.num_layers
        use_segments = self._hparams.use_segments
        untie_r = self._hparams.untie_r

        with tf.variable_scope(self.variable_scope):

            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer(
                    layers.get_initializer(self._hparams.initializer))

            if untie_r:
                self.r_w_bias = tf.get_variable(
                    'r_w_bias',
                    [num_layers, self._hparams.num_heads,
                     self._hparams.head_dim],
                    dtype=tf.float32)
                self.r_r_bias = tf.get_variable(
                    'r_r_bias',
                    [num_layers, self._hparams.num_heads,
                     self._hparams.head_dim],
                    dtype=tf.float32)
            else:
                self.r_w_bias = tf.get_variable(
                    'r_w_bias',
                    [self._hparams.num_heads, self._hparams.head_dim],
                    dtype=tf.float32)
                self.r_r_bias = tf.get_variable(
                    'r_r_bias',
                    [self._hparams.num_heads, self._hparams.head_dim],
                    dtype=tf.float32)

            if use_segments:
                self.segment_embed = tf.get_variable(
                    'seg_embed',
                    [num_layers, 2, self._hparams.num_heads,
                     self._hparams.head_dim],
                    dtype=tf.float32)
                if untie_r:
                    self.r_s_bias = tf.get_variable(
                        'r_s_bias',
                        [num_layers, self._hparams.num_heads,
                         self._hparams.head_dim],
                        dtype=tf.float32)
                else:
                    self.r_s_bias = tf.get_variable(
                        'r_s_bias',
                        [self._hparams.num_heads, self._hparams.head_dim],
                        dtype=tf.float32)
            else:
                self.segment_embed = None
                self.r_s_bias = None

            # Word embedding
            self.word_embedder = WordEmbedder(
                vocab_size=self._hparams.vocab_size,
                hparams={"dim": self._hparams.hidden_dim})

            # Position embedding
            self.pos_embed = RelativePositionalEncoding(
                hparams={
                    "dim": self._hparams.hidden_dim,
                    "max_seq_len": self._hparams.max_seq_len
                })

            self.attn_layers = []
            self.ff_layers = []
            rel_attn_hparams = dict_fetch(
                self._hparams, RelativeMutiheadAttention.default_hparams())
            rel_attn_hparams["name"] = "rel_attn"

            ff_hparams = dict_fetch(self._hparams,
                                    PositionWiseFF.default_hparams())
            ff_hparams["name"] = "ff"

            for i in range(num_layers):
                with tf.variable_scope("layer_{}".format(i)):
                    if self._hparams.untie_r:
                        if use_segments:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias[i],
                                    self.r_w_bias[i],
                                    self.r_s_bias[i],
                                    self.segment_embed[i],
                                    hparams=rel_attn_hparams))
                        else:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias[i],
                                    self.r_w_bias[i],
                                    hparams=rel_attn_hparams))
                    else:
                        if use_segments:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias,
                                    self.r_w_bias,
                                    self.r_s_bias,
                                    self.segment_embed[i],
                                    hparams=rel_attn_hparams))
                        else:
                            self.attn_layers.append(
                                RelativeMutiheadAttention(
                                    self.r_r_bias,
                                    self.r_w_bias,
                                    hparams=rel_attn_hparams))
                    self.ff_layers.append(PositionWiseFF(hparams=ff_hparams))

            dropout_hparams = {
                "type": "Dropout",
                "kwargs": {
                    "rate": self._hparams.dropout
                }
            }
            self.dropout = layers.get_layer(hparams=dropout_hparams)

            self.mask_embed = tf.get_variable('mask_emb',
                                              [1, 1, self.hparams.hidden_dim],
                                              dtype=tf.float32)

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of hyperparameters with default values.

        * The encoder arch is determined by the constructor argument \
        :attr:`pretrained_model_name` if it's specified. In this case, \
        hparams are ignored.
        * Otherwise, the encoder arch is determined by \
        `hparams['pretrained_model_name']` if it's specified. All other \
        configs in hparams are ignored.
        * If the above two are `None`, the encoder arch is defined by \
        the configs in hparams and weights are randomly initialized.

        .. code-block:: python

            {
                "name": "xlnet_encoder",
                "pretrained_model_name": "xlnet-base-cased",
                "untie_r": True,
                "num_layers": 12,
                "mem_len": 0,
                "reuse_len": 0,
                "initializer": None,
                "num_heads": 12,
                "hidden_dim": 768,
                "head_dim": 64,
                "dropout": 0.1,
                "attention_dropout": 0.1,
                "use_segments": True,
                "ffn_inner_dim": 3072,
                "activation": 'gelu',
                "vocab_size": 32000,
                "max_seq_len": 512,
            }



        Here:

        The default parameters correspond to the cased XLNet-Base model.


        "pretrained_model_name": str or None
            The name of the pre-trained XLNet model. If `None`, the model
            will be randomly initialized.

        "untie_r": bool
            Boolean value to indicate if biases should be untied for all the
            layers

        "num_layers": int
            Number of layers in the network

        "mem_len": int
            Length of the memory to be used during attention score calculation.

        "reuse_len": int
            Length of the memory that can be re-used

        "initializer": dict, optional
            Hyperparameters of the default initializer that initializes
            variables created in this module.
            See :func:`~texar.tf.core.get_initializer` for details.

        "num_heads": int
            Number of heads in the attention

        "hidden_dim": int
            Hidden dimension of the embeddings

        "head_dim": int
            Size of the vectors after head projection.

        "dropout": float
            Dropout rate for layers

        "attention_dropout": float
            Dropout rate for attention layers

        "use_segments": bool
            Boolean to indicate if the input has segments

        "ffn_inner_dim": int
            Dimension of PositionWise FF network's hidden layer

        "activation": str or callable
            Activation function applied to the output of the PositionWise FF.
            See :func:`~texar.tf.core.get_activation_fn` for more details.

        "vocab_size": int
            The vocabulary size of `inputs` in `XLNet`.

        "max_seq_len": int
            Maximum len of the sequence allowed in one segment

        "name": str
            Name of the module.
        """

        return {
            "name": "xlnet_encoder",
            'pretrained_model_name': 'xlnet-base-cased',
            "untie_r": True,
            "num_layers": 12,
            "mem_len": 0,
            "reuse_len": 0,
            # initializer
            "initializer": None,
            # layer
            "num_heads": 12,
            "hidden_dim": 768,
            "head_dim": 64,
            "dropout": 0.1,
            "attention_dropout": 0.1,
            "use_segments": True,
            # ffn
            "ffn_inner_dim": 3072,
            "activation": 'gelu',
            # embedding
            "vocab_size": 32000,
            "max_seq_len": 512,
            '@no_typecheck': ['pretrained_model_name']
        }
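
    # Overriding a few of the defaults while training from scratch might look
    # like the following sketch (values are illustrative only); unspecified
    # entries fall back to `default_hparams()`:
    #
    #     encoder = XLNetEncoder(hparams={
    #         "pretrained_model_name": None,
    #         "num_layers": 6,
    #         "vocab_size": 32000,
    #     })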

    def param_groups(self,
                     lr=None,
                     lr_layer_scale=1.0,
                     decay_base_params=False):
        r"""Create parameter groups for optimizers. When
        :attr:`lr_layer_decay_rate` is not 1.0, parameters from each layer form
        separate groups with different base learning rates.

        This method should be called before applying gradients to the variables
        through the optimizer. Particularly, after calling the optimizer's
        `compute_gradients` method, the user can call this method to get
        variable-specific learning rates for the network. The gradients for each
        variables can then be scaled accordingly. These scaled gradients are
        finally applied by calling optimizer's `apply_gradients` method.

        Example:

            .. code-block:: python

                grads_and_vars = optimizer.compute_gradients(loss)

                # Map each variable to its gradient.
                vars_to_grads = {var: grad for grad, var in grads_and_vars}

                vars_to_learning_rates = xlnet_encoder.param_groups(
                    lr=1, lr_layer_scale=0.75)

                for var in vars_to_grads:
                    vars_to_grads[var] *= vars_to_learning_rates[var]

                train_op = optimizer.apply_gradients(
                    [(grad, var) for var, grad in vars_to_grads.items()])


        Args:
            lr (float): The learning rate. Can be omitted if
                :attr:`lr_layer_scale` is 1.0.
            lr_layer_scale (float): Per-layer LR scaling rate. The `i`-th layer
                will be scaled by `lr_layer_scale ^ (num_layers - i - 1)`.
            decay_base_params (bool): If `True`, treat non-layer parameters
                (e.g. embeddings) as if they're in layer 0. If `False`, these
                parameters are not scaled.

        Returns: A dict mapping TensorFlow variables to their learning rates.
        """
        vars_to_learning_rates = {}
        if lr_layer_scale != 1.0:
            if lr is None:
                raise ValueError(
                    "lr must be specified when lr_layer_scale is not 1.0")

            num_layers = self._hparams.num_layers
            scope = self.variable_scope.name
            base_var_names = ['r_w_bias', 'r_r_bias', 'word_embedder']

            if self._hparams.use_segments:
                base_var_names.extend(['r_s_bias', 'seg_embed'])

            for var in base_var_names:
                tf_variable = tf.trainable_variables(scope=scope + "/" +
                                                     var)[0]
                vars_to_learning_rates[tf_variable] = \
                    lr * (lr_layer_scale ** num_layers if decay_base_params
                          else 1.0)

            for idx in range(num_layers):
                decay_rate = lr_layer_scale**(num_layers - idx - 1)
                layer_variables = tf.trainable_variables(
                    scope=scope + "/" + "layer_{}".format(idx))
                for variable in layer_variables:
                    vars_to_learning_rates[variable] = lr * decay_rate
        else:
            for variable in self.trainable_variables:
                vars_to_learning_rates[variable] = lr

        return vars_to_learning_rates

    @property
    def output_size(self):
        r"""The last dimension of the encoder output.

        Note: :meth:`_build` returns two tensors of shapes
        `[batch_size, max_time, hidden_dim]` and
        `[batch_size, cache_len, hidden_dim]`; `output_size` here equals
        `hidden_dim`.
        """
        return self._hparams.hidden_dim

    @staticmethod
    def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
        r"""Cache hidden states into memory."""
        assert mem_len > 0

        if reuse_len is not None and reuse_len > 0:
            curr_out = curr_out[:reuse_len]

        if prev_mem is None:
            new_mem = curr_out[-mem_len:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]

        return tf.stop_gradient(new_mem)
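
    # Worked example of the caching rule above (time-major tensors): with
    # mem_len=4, prev_mem = [m1, m2, m3, m4] and curr_out = [c1, c2, c3],
    # the concatenation is [m1, m2, m3, m4, c1, c2, c3] and the new memory
    # keeps the last four steps, [m4, c1, c2, c3], with gradients stopped.
    # If `reuse_len` is set, only the first `reuse_len` steps of `curr_out`
    # are eligible for caching.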

    def _create_mask(self, qlen, mlen, dtype=tf.float32, same_length=False):
        r"""Create causal attention mask."""
        attn_mask = tf.ones([qlen, qlen], dtype=dtype)
        mask_u = tf.matrix_band_part(attn_mask, 0, -1)
        mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
        attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
        ret = tf.concat([attn_mask_pad, mask_u - mask_dia], axis=1)
        if same_length:
            mask_l = tf.matrix_band_part(attn_mask, -1, 0)
            ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]],
                            axis=1)

        return ret
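
    # Worked example: for qlen=3, mlen=1 and same_length=False the returned
    # mask is
    #
    #     [[0., 0., 1., 1.],
    #      [0., 0., 0., 1.],
    #      [0., 0., 0., 0.]]
    #
    # where column 0 is the (always visible) memory position and a value of 1
    # at [i, j] means query position i may not attend to position j, i.e.
    # strictly future positions are masked.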

    def _build(self,
               token_ids,
               segment_ids=None,
               input_mask=None,
               memory=None,
               permute_mask=None,
               target_mapping=None,
               bi_data=False,
               clamp_len=None,
               cache_len=0,
               same_length=False,
               attn_type='bi',
               two_stream=False,
               mode=None):
        r"""Compute XLNet representations for the input.

        Args:
            token_ids: Shape `[batch_size, max_time]`.
            segment_ids: Shape `[batch_size, max_time]`.
            input_mask: Float tensor of shape `[batch_size, max_time]`. Note that
                positions with value 1 are masked out.
            memory: Memory from previous batches. A list of length `num_layers`,
                each tensor of shape `[batch_size, mem_len, hidden_dim]`.
            permute_mask: The permutation mask. Float tensor of shape
                `[batch_size, max_time, max_time]`.
                A value of 0 for ``permute_mask[b, i, j]`` indicates that
                position `i` attends to position `j` in batch `b`.
            target_mapping: The target token mapping. Float tensor of shape
                `[batch_size, num_targets, max_time]`.
                A value of 1 for ``target_mapping[b, i, j]`` indicates that
                the `i`-th target token (in order of permutation) in batch `b`
                is the token at position `j`.
                Each row ``target_mapping[b, i, :]`` can have no more than one
                value of 1.
            bi_data (bool): Whether to use bidirectional data input pipeline.
            clamp_len (int): Clamp all relative distances larger than
                :attr:`clamp_len`. A value of -1 means no clamping.
            cache_len (int): Length of memory (number of tokens) to cache.
            same_length (bool): Whether to use the same attention length for
                each token.
            attn_type (str): Attention type. Supported values are `"uni"`
                and `"bi"`.
            two_stream (bool): Whether to use two-stream attention. Only set to
                `True` when pre-training or generating text. Defaults to
                `False`.

        Returns: A tuple of `(output, new_memory)`:

            - **output**: The final layer output representations. Shape
              `[batch_size, max_time, hidden_dim]`.
            - **new_memory**: The memory of the current batch.
              If `cache_len` is 0, then `new_memory` is `None`. Otherwise, it is
              a list of length `num_layers`, each tensor of shape
              `[batch_size, cache_len, hidden_dim]`.
              This can be used as the :attr:`memory` argument in the next batch.
        """
        return self._execute(self.word_embedder(token_ids),
                             segment_ids=segment_ids,
                             input_mask=input_mask,
                             memory=memory,
                             permute_mask=permute_mask,
                             target_mapping=target_mapping,
                             bi_data=bi_data,
                             clamp_len=clamp_len,
                             cache_len=cache_len,
                             same_length=same_length,
                             attn_type=attn_type,
                             two_stream=two_stream,
                             mode=mode)

    def _execute(
            self,
            word_embed,
            segment_ids=None,  # noqa: C901
            input_mask=None,
            memory=None,
            permute_mask=None,
            target_mapping=None,
            bi_data=False,
            clamp_len=None,
            cache_len=0,
            same_length=False,
            attn_type='bi',
            two_stream=False,
            mode=None):
        r"""Compute XLNet representations for the input. This layer exists
        because :class:`XLNetDecoder` compute embeddings in the decoder helper.
        `word_embed` has shape `[batch_size, max_time, word_embed_dim]`.
        Please refer to :meth:`_build` for the detailed information of other
        arguments.
        """
        # word_embed: [max_time, batch_size, word_embed_dim]
        word_embed = tf.transpose(word_embed, perm=[1, 0, 2])
        # segment_ids: [max_time, batch_size]
        if segment_ids is not None:
            segment_ids = tf.transpose(segment_ids, perm=[1, 0])
        # input_mask: [max_time, batch_size]
        if input_mask is not None:
            input_mask = tf.transpose(input_mask, perm=[1, 0])
        # memory: A list of length num_layers
        # each tensor of shape [mem_len, batch_size, hidden_dim]
        if memory is not None:
            memory = [tf.transpose(m, perm=[1, 0, 2]) for m in memory]
        # permute_mask: [max_time, max_time, batch_size]
        if permute_mask is not None:
            permute_mask = tf.transpose(permute_mask, perm=[1, 2, 0])
        # target_mapping: [num_targets, max_time, batch_size]
        if target_mapping is not None:
            target_mapping = tf.transpose(target_mapping, perm=[1, 2, 0])

        max_time = tf.shape(word_embed)[0]
        batch_size = tf.shape(word_embed)[1]
        mem_len = tf.shape(memory[0])[0] if memory is not None else 0
        tot_len = max_time + mem_len
        reuse_len = self._hparams.reuse_len
        is_training = is_train_mode(mode)

        # Attention mask
        # causal attention mask
        if attn_type == 'uni':
            attn_mask = self._create_mask(max_time, mem_len, tf.float32,
                                          same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError(
                'Unsupported attention type: {}'.format(attn_type))

        # data mask: input mask & perm mask
        if input_mask is not None and permute_mask is not None:
            data_mask = input_mask[None] + permute_mask
        elif input_mask is not None and permute_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and permute_mask is not None:
            data_mask = permute_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mem_len, batch_size],
                                 dtype=tf.float32)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf.float32)

        if attn_mask is not None:
            non_tgt_mask = -tf.eye(max_time, dtype=tf.float32)
            non_tgt_mask = tf.concat(
                [tf.zeros([max_time, mem_len], dtype=tf.float32),
                 non_tgt_mask], axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf.float32)
        else:
            non_tgt_mask = None

        # Segment embedding
        if segment_ids is not None:
            mem_pad = tf.zeros([mem_len, batch_size], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, segment_ids], 0)
            segment_matrix = tf.cast(
                tf.logical_not(tf.equal(segment_ids[:, None],
                                        cat_ids[None, :])), tf.int32)
            segment_matrix = tf.one_hot(segment_matrix, 2, dtype=tf.float32)
        else:
            segment_matrix = None

        # Position embedding
        pos_embed = self.pos_embed(batch_size, max_time, tot_len, clamp_len,
                                   attn_type, bi_data)
        pos_embed = self.dropout(pos_embed, training=is_training)

        states_h = self.dropout(word_embed, training=is_training)

        if two_stream:
            if target_mapping is not None:
                word_embed_q = tf.tile(
                    self.mask_embed,
                    [tf.shape(target_mapping)[0], batch_size, 1])
            else:
                word_embed_q = word_embed
            states_g = self.dropout(word_embed_q, training=is_training)
        else:
            states_g = None

        new_memory = []
        num_layers = self._hparams.num_layers
        for i in range(num_layers):
            cur_memory = memory[i] if memory is not None else None
            if cache_len > 0:
                new_memory.append(
                    self._cache_mem(states_h, cur_memory, cache_len,
                                    reuse_len))
            states_h, states_g = self.attn_layers[i](
                states_h=states_h,
                pos_embed=pos_embed,
                states_g=states_g,
                segment_mat=segment_matrix,
                attn_mask_h=non_tgt_mask,
                attn_mask_g=attn_mask,
                target_mapping=None,
                memory=cur_memory,
                mode=mode)
            ff_layer = self.ff_layers[i]
            states_h = ff_layer(states_h, mode=mode)

            if states_g is not None:
                states_g = ff_layer(states_g, mode=mode)

        output = self.dropout(states_h if states_g is None else states_g,
                              training=is_training)

        # Now output: [max_time, batch_size, hidden_dim]
        # new_memory: None or A list of length num_layers,
        # each tensor of shape [cache_len, batch_size, hidden_dim]
        output = tf.transpose(output, perm=[1, 0, 2])
        if new_memory is not None:
            new_memory = [tf.transpose(m, perm=[1, 0, 2]) for m in new_memory]

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

            if self.pretrained_model_dir:
                xlnet_utils.init_xlnet_checkpoint(self.pretrained_model_dir,
                                                  self.variable_scope.name)

        if cache_len == 0:
            return output, None

        return output, new_memory
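
# A minimal usage sketch (illustrative; assumes the `texar.tf` package layout,
# that `XLNetEncoder` is exposed under `texar.tf.modules`, and the
# "xlnet-base-cased" checkpoint named in `default_hparams` above). Calling the
# module invokes `_build`, which returns the top-layer states and, when
# `cache_len > 0`, the per-layer memory for the next segment.
import tensorflow as tf
import texar.tf as tx

xlnet = tx.modules.XLNetEncoder(pretrained_model_name="xlnet-base-cased")
token_ids = tf.placeholder(tf.int32, shape=[None, None])
output, new_memory = xlnet(token_ids, cache_len=128)
# output:     [batch_size, max_time, hidden_dim]
# new_memory: list of num_layers tensors, each [batch_size, 128, hidden_dim]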