Code Example #1
    def __init__(
            self,
            memory,
            memory_sequence_length=None,
            cell=None,
            cell_dropout_mode=None,
            vocab_size=None,
            output_layer=None,
            #attention_layer=None, # TODO(zhiting): only valid for tf>=1.0
            cell_input_fn=None,
            hparams=None):
        RNNDecoderBase.__init__(self, cell, vocab_size, output_layer,
                                cell_dropout_mode, hparams)

        # Attention hyperparameters; convert the `kwargs` sub-dict to a
        # plain Python dict so it can be modified below.
        attn_hparams = self._hparams['attention']
        attn_kwargs = attn_hparams['kwargs'].todict()

        # Parse the 'probability_fn' argument
        if 'probability_fn' in attn_kwargs:
            prob_fn = attn_kwargs['probability_fn']
            if prob_fn is not None and not callable(prob_fn):
                prob_fn = utils.get_function(prob_fn, [
                    'tensorflow.nn', 'tensorflow.contrib.sparsemax',
                    'tensorflow.contrib.seq2seq'
                ])
            attn_kwargs['probability_fn'] = prob_fn

        # `memory` and `memory_sequence_length` come from the constructor
        # arguments rather than from hparams.
        attn_kwargs.update({
            "memory_sequence_length": memory_sequence_length,
            "memory": memory
        })
        self._attn_kwargs = attn_kwargs
        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.tf.custom']
        # Use variable_scope to ensure all trainable variables created in
        # the attention mechanism are collected
        with tf.variable_scope(self.variable_scope):
            attention_mechanism = utils.check_or_get_instance(
                attn_hparams["type"],
                attn_kwargs,
                attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

        self._attn_cell_kwargs = {
            "attention_layer_size": attn_hparams["attention_layer_size"],
            "alignment_history": attn_hparams["alignment_history"],
            "output_attention": attn_hparams["output_attention"],
        }
        self._cell_input_fn = cell_input_fn
        # Use variable_scope to ensure all trainable variables created in
        # AttentionWrapper are collected
        with tf.variable_scope(self.variable_scope):
            #if attention_layer is not None:
            #    self._attn_cell_kwargs["attention_layer_size"] = None
            attn_cell = AttentionWrapper(
                self._cell,
                attention_mechanism,
                cell_input_fn=self._cell_input_fn,
                #attention_layer=attention_layer,
                **self._attn_cell_kwargs)
            self._cell = attn_cell
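
The snippet above resolves the attention hyperparameters, builds the attention mechanism, and wraps the decoder cell in an `AttentionWrapper`. Below is a minimal usage sketch, assuming this constructor belongs to `texar.tf.modules.AttentionRNNDecoder`; the names `enc_outputs`, `seq_lengths`, and `vocab_size` are illustrative placeholders for values produced by an upstream encoder and data pipeline, not part of the example above:

import texar.tf as tx

# Hypothetical usage sketch: `enc_outputs` is a [batch, time, dim] tensor of
# encoder states; `seq_lengths` gives the true length of each memory sequence.
decoder = tx.modules.AttentionRNNDecoder(
    memory=enc_outputs,
    memory_sequence_length=seq_lengths,
    vocab_size=vocab_size)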
Code Example #2
    def __init__(self,
                 cell=None,
                 cell_dropout_mode=None,
                 vocab_size=None,
                 output_layer=None,
                 hparams=None):
        RNNDecoderBase.__init__(self, cell, vocab_size, output_layer,
                                cell_dropout_mode, hparams)
Code Example #3
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "rnn_cell": default_rnn_cell_hparams(),
                "max_decoding_length_train": None,
                "max_decoding_length_infer": None,
                "helper_train": {
                    "type": "TrainingHelper",
                    "kwargs": {}
                },
                "helper_infer": {
                    "type": "SampleEmbeddingHelper",
                    "kwargs": {}
                },
                "name": "basic_rnn_decoder"
            }

        Here:

        "rnn_cell": dict
            A dictionary of RNN cell hyperparameters. Ignored if
            :attr:`cell` is given to the decoder constructor.
            The default value is defined in
            :func:`~texar.tf.core.default_rnn_cell_hparams`.

        "max_decoding_length_train": int or None
            Maximum allowed number of decoding steps in training mode.
            If `None` (default), decoding is
            performed until fully done, e.g., encountering the <EOS> token.
            Ignored if `max_decoding_length` is given when calling
            the decoder.

        "max_decoding_length_infer": int or None
            Same as "max_decoding_length_train" but for inference mode.

        "helper_train": dict
            The hyperparameters of the helper used in training.
            "type" can be a helper class, its name or module path, or a
            helper instance. If a class name is given, the class must be
            from module :tf_main:`tf.contrib.seq2seq <contrib/seq2seq>`,
            :mod:`texar.tf.modules`, or :mod:`texar.tf.custom`. This is used
            only when both the `decoding_strategy` and `helper` arguments are
            `None` when calling the decoder. See
            :meth:`~texar.tf.modules.RNNDecoderBase._build` for more details.

        "helper_infer": dict
            Same as "helper_train" but during inference mode.

        "name": str
            Name of the decoder.

            The default value is "basic_rnn_decoder".
        """
        hparams = RNNDecoderBase.default_hparams()
        hparams["name"] = "basic_rnn_decoder"
        return hparams
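
A sketch of overriding some of the defaults documented above when constructing the decoder; the specific values and the `BasicRNNDecoder` usage are illustrative assumptions, and keys left unspecified keep their default values:

# Hypothetical override sketch based on the hparams structure above.
hparams = {
    "rnn_cell": {"kwargs": {"num_units": 512}},
    "max_decoding_length_infer": 100,
    "helper_infer": {"type": "GreedyEmbeddingHelper"},
}
decoder = BasicRNNDecoder(vocab_size=10000, hparams=hparams)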
Code Example #4
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values:

        Common hyperparameters are the same as in
        :class:`~texar.tf.modules.BasicRNNDecoder`.
        :meth:`~texar.tf.modules.BasicRNNDecoder.default_hparams`.
        Additional hyperparameters are for attention mechanism
        configuration.

        .. code-block:: python

            {
                "attention": {
                    "type": "LuongAttention",
                    "kwargs": {
                        "num_units": 256,
                    },
                    "attention_layer_size": None,
                    "alignment_history": False,
                    "output_attention": True,
                },
                # The following hyperparameters are the same as with
                # `BasicRNNDecoder`
                "rnn_cell": default_rnn_cell_hparams(),
                "max_decoding_length_train": None,
                "max_decoding_length_infer": None,
                "helper_train": {
                    "type": "TrainingHelper",
                    "kwargs": {}
                },
                "helper_infer": {
                    "type": "SampleEmbeddingHelper",
                    "kwargs": {}
                },
                "name": "attention_rnn_decoder"
            }

        Here:

        "attention": dict
            Attention hyperparameters, including:

            "type": str or class or instance
                The attention type. Can be an attention class, its name or
                module path, or a class instance. The class must be a subclass
                of :tf_main:`TF AttentionMechanism
                <contrib/seq2seq/AttentionMechanism>`. If class name is
                given, the class must be from modules
                :tf_main:`tf.contrib.seq2seq <contrib/seq2seq>` or
                :mod:`texar.tf.custom`.

                Example:

                    .. code-block:: python

                        # class name
                        "type": "LuongAttention"
                        "type": "BahdanauAttention"
                        # module path
                        "type": "tf.contrib.seq2seq.BahdanauMonotonicAttention"
                        "type": "my_module.MyAttentionMechanismClass"
                        # class
                        "type": tf.contrib.seq2seq.LuongMonotonicAttention
                        # instance
                        "type": LuongAttention(...)

            "kwargs": dict
                keyword arguments for the attention class constructor.
                Arguments :attr:`memory` and
                :attr:`memory_sequence_length` should **not** be
                specified here because they are given to the decoder
                constructor. Ignored if "type" is an attention class
                instance.

                Example:

                    .. code-block:: python

                        "type": "LuongAttention",
                        "kwargs": {
                            "num_units": 256,
                            "probability_fn": tf.nn.softmax
                        }

                    Here "probability_fn" can also be set to the string name
                    or module path to a probability function.

                "attention_layer_size": int or None
                    The depth of the attention (output) layer. The context and
                    cell output are fed into the attention layer to generate
                    attention at each time step.
                    If `None` (default), use the context as attention at each
                    time step.

                "alignment_history": bool
                    whether to store alignment history from all time steps
                    in the final output state. (Stored as a time major
                    `TensorArray` on which you must call `stack()`.)

                "output_attention": bool
                    If `True` (default), the output at each time step is
                    the attention value. This is the behavior of Luong-style
                    attention mechanisms. If `False`, the output at each
                    time step is the output of `cell`.  This is the
                    beahvior of Bhadanau-style attention mechanisms.
                    In both cases, the `attention` tensor is propagated to
                    the next time step via the state and is used there.
                    This flag only controls whether the attention mechanism
                    is propagated up to the next cell in an RNN stack or to
                    the top RNN output.
        """
        hparams = RNNDecoderBase.default_hparams()
        hparams["name"] = "attention_rnn_decoder"
        hparams["attention"] = {
            "type": "LuongAttention",
            "kwargs": {
                "num_units": 256,
            },
            "attention_layer_size": None,
            "alignment_history": False,
            "output_attention": True,
        }
        return hparams
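
A sketch of customizing the "attention" block of these hparams; the `BahdanauAttention` settings and the surrounding variable names (`enc_outputs`, `seq_lengths`, `vocab_size`) are illustrative assumptions:

# Hypothetical customization sketch for the "attention" hparams above.
hparams = {
    "attention": {
        "type": "BahdanauAttention",
        "kwargs": {"num_units": 512},
        "output_attention": False,  # Bahdanau-style: emit the cell output
    }
}
decoder = AttentionRNNDecoder(
    memory=enc_outputs,
    memory_sequence_length=seq_lengths,
    vocab_size=vocab_size,
    hparams=hparams)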