Example #1
def get_gradient_clip_fn(hparams=None):
    """Creates a gradient clipping function based on the hyperparameters.

    See the :attr:`gradient_clip` field in
    :meth:`~texar.core.optimization.default_optimization_hparams` for all
    hyperparameters and default values.

    The gradient clipping function takes a list of `(gradients, variables)`
    tuples and returns a list of `(clipped_gradients, variables)` tuples.
    Typical examples include
    :tf_main:`tf.clip_by_global_norm <clip_by_global_norm>`,
    :tf_main:`tf.clip_by_value <clip_by_value>`,
    :tf_main:`tf.clip_by_norm <clip_by_norm>`,
    :tf_main:`tf.clip_by_average_norm <clip_by_average_norm>`, etc.

    Args:
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically.

    Returns:
        function or `None`: If :attr:`hparams["type"]` is specified, returns
        the respective function. If :attr:`hparams["type"]` is empty,
        returns `None`.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams,
                          default_optimization_hparams()["gradient_clip"])
    fn_type = hparams["type"]
    if fn_type is None or fn_type == "":
        return None

    fn_modules = ["tensorflow", "texar.custom"]
    clip_fn = utils.get_function(fn_type, fn_modules)
    clip_fn_args = inspect.getargspec(clip_fn).args
    fn_kwargs = hparams["kwargs"]
    if isinstance(fn_kwargs, HParams):
        fn_kwargs = fn_kwargs.todict()

    def grad_clip_fn(grads_and_vars):
        """Gradient clipping function.

        Args:
            grads_and_vars (list): A list of `(gradients, variables)` tuples.

        Returns:
            list: A list of `(clipped_gradients, variables)` tuples.
        """
        grads, vars_ = zip(*grads_and_vars)
        if clip_fn == tf.clip_by_global_norm:
            clipped_grads, _ = clip_fn(t_list=grads, **fn_kwargs)
        elif 't_list' in clip_fn_args:
            clipped_grads = clip_fn(t_list=grads, **fn_kwargs)
        elif 't' in clip_fn_args:  # e.g., tf.clip_by_value
            clipped_grads = [clip_fn(t=grad, **fn_kwargs) for grad in grads]

        return list(zip(clipped_grads, vars_))

    return grad_clip_fn
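A minimal usage sketch (the hyperparameter values are illustrative, not defaults; "clip_by_global_norm" resolves to `tf.clip_by_global_norm` via the module list above):

clip_hparams = {
    "type": "clip_by_global_norm",
    "kwargs": {"clip_norm": 5.0},
}
grad_clip_fn = get_gradient_clip_fn(clip_hparams)
# grads_and_vars = optimizer.compute_gradients(loss)
# clipped_grads_and_vars = grad_clip_fn(grads_and_vars)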
Example #2
 def _make_bucket_length_fn(self):
     length_fn = self._hparams.bucket_length_fn
     if not length_fn:
         length_fn = lambda x: tf.maximum(x[self.source_length_name], x[
             self.target_length_name])
     elif not is_callable(length_fn):
         # pylint: disable=redefined-variable-type
         length_fn = utils.get_function(length_fn, ["texar.custom"])
     return length_fn
Example #3
def get_initializer(hparams=None):
    """Returns an initializer instance.

    .. role:: python(code)
       :language: python

    Args:
        hparams (dict or HParams, optional): Hyperparameters with the structure

            .. code-block:: python

                {
                    "type": "initializer_class_or_function",
                    "kwargs": {
                        #...
                    }
                }

            The "type" field can be a initializer class, its name or module
            path, or class instance. If class name is provided, the class must
            be from one the following modules:
            :tf_main:`tf.initializers <initializers>`,
            :tf_main:`tf.keras.initializers <keras/initializers>`,
            :tf_main:`tf < >`, and :mod:`texar.custom`. The class is created
            by :python:`initializer_class(**kwargs)`. If a class instance
            is given, "kwargs" is ignored and can be omitted.

            Besides, the "type" field can also be an initialization function
            called with :python:`initialization_fn(**kwargs)`. In this case
            "type" can be the function, or its name or module path. If
            function name is provided, the function must be from one of the
            above modules or module `tf.contrib.layers`. If no
            keyword argument is required, "kwargs" can be omitted.

    Returns:
        An initializer instance. `None` if :attr:`hparams` is `None`.
    """
    if hparams is None:
        return None

    kwargs = hparams.get("kwargs", {})
    if isinstance(kwargs, HParams):
        kwargs = kwargs.todict()
    modules = [
        "tensorflow.initializers", "tensorflow.keras.initializers",
        "tensorflow", "texar.custom"
    ]
    try:
        initializer = utils.check_or_get_instance(hparams["type"], kwargs,
                                                  modules)
    except TypeError:
        modules += ['tensorflow.contrib.layers']
        initializer_fn = utils.get_function(hparams["type"], modules)
        initializer = initializer_fn(**kwargs)

    return initializer
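For illustration, a configuration that resolves to a class from `tf.initializers` (the `stddev` value is arbitrary):

init_hparams = {
    "type": "random_normal",        # resolved to tf.initializers.random_normal
    "kwargs": {"stddev": 0.02},
}
initializer = get_initializer(init_hparams)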
Example #4
    def __init__(self,
                 memory,
                 memory_sequence_length=None,
                 cell_input_fn=None,
                 cell=None,
                 cell_dropout_mode=None,
                 vocab_size=None,
                 output_layer=None,
                 hparams=None):
        RNNDecoderBase.__init__(self, cell, vocab_size, output_layer,
                                cell_dropout_mode, hparams)

        attn_hparams = self._hparams['attention']
        attn_kwargs = attn_hparams['kwargs'].todict()

        # Parse the 'probability_fn' argument
        if 'probability_fn' in attn_kwargs:
            prob_fn = attn_kwargs['probability_fn']
            if prob_fn is not None and not callable(prob_fn):
                prob_fn = utils.get_function(prob_fn, [
                    'tensorflow.nn', 'tensorflow.contrib.sparsemax',
                    'tensorflow.contrib.seq2seq'
                ])
            attn_kwargs['probability_fn'] = prob_fn

        attn_kwargs.update({
            "memory_sequence_length": memory_sequence_length,
            "memory": memory
        })
        self._attn_kwargs = attn_kwargs
        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.custom']
        # Use variable_scope to ensure all trainable variables created in
        # the attention mechanism are collected
        with tf.variable_scope(self.variable_scope):
            attention_mechanism = utils.check_or_get_instance(
                attn_hparams["type"],
                attn_kwargs,
                attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

        self._attn_cell_kwargs = {
            "attention_layer_size": attn_hparams["attention_layer_size"],
            "alignment_history": attn_hparams["alignment_history"],
            "output_attention": attn_hparams["output_attention"],
        }
        self._cell_input_fn = cell_input_fn
        # Use variable_scope to ensure all trainable variables created in
        # AttentionWrapper are collected
        with tf.variable_scope(self.variable_scope):
            attn_cell = AttentionWrapper(self._cell,
                                         attention_mechanism,
                                         cell_input_fn=self._cell_input_fn,
                                         **self._attn_cell_kwargs)
            self._cell = attn_cell
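A rough sketch of the hyperparameter structure this constructor reads (values are illustrative; omitted fields fall back to the decoder's defaults):

decoder_hparams = {
    "attention": {
        "type": "LuongAttention",        # class from tf.contrib.seq2seq
        "kwargs": {"num_units": 256},
        "attention_layer_size": 256,
        "alignment_history": False,
        "output_attention": True,
    },
}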
Example #5
def get_activation_fn(fn_name: Optional[Union[str,
                                              Callable[[torch.Tensor],
                                                       torch.Tensor]]] = None,
                      kwargs: Union[HParams, Dict, None] = None) \
        -> Optional[Callable[[torch.Tensor], torch.Tensor]]:
    r"""Returns an activation function `fn` with the signature
    `output = fn(input)`.

    If the function specified by :attr:`fn_name` has more than one argument
    without a default value, then all these arguments except the input feature
    argument must be specified in :attr:`kwargs`. Arguments with default values
    can also be specified in :attr:`kwargs` to take values other than the
    defaults. In this case a partial function is returned with the above
    signature.

    Args:
        fn_name (str or callable): An activation function, or its name or
            module path. The function can be:

            - Built-in function defined in
              :torch_docs:`torch.nn.functional<nn.html#torch-nn-functional>`
            - User-defined activation functions in module :mod:`texar.custom`.
            - External activation functions. Must provide the full module path,
              e.g., ``"my_module.my_activation_fn"``.

        kwargs (optional): A `dict` or instance of :class:`~texar.HParams`
            containing the keyword arguments of the activation function.

    Returns:
        An activation function. `None` if :attr:`fn_name` is `None`.
    """
    if fn_name is None:
        return None

    fn_modules = [
        'torch', 'torch.nn.functional', 'texar.custom', 'texar.core.layers'
    ]
    activation_fn_ = utils.get_function(fn_name, fn_modules)
    activation_fn = activation_fn_

    # Make a partial function if necessary
    if kwargs is not None:
        if isinstance(kwargs, HParams):
            kwargs = kwargs.todict()

        def _partial_fn(features):
            return activation_fn_(features, **kwargs)

        activation_fn = _partial_fn

    return activation_fn
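Usage sketch (the keyword value is illustrative):

import torch

relu_fn = get_activation_fn("relu")                           # plain built-in
leaky_fn = get_activation_fn("leaky_relu",
                             kwargs={"negative_slope": 0.2})  # partial function
out = leaky_fn(torch.randn(2, 4))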
Example #6
 def _make_bucket_length_fn(self):
     length_fn = self._hparams.bucket_length_fn
     if not length_fn:
         # Uses the length of the first text data
         i = -1
         for i, hparams_i in enumerate(self._hparams.datasets):
             if _is_text_data(hparams_i["data_type"]):
                 break
         if i < 0:
             raise ValueError("Undefined `length_fn`.")
         length_fn = lambda x: x[self.length_name(i)]
     elif not is_callable(length_fn):
         # pylint: disable=redefined-variable-type
         length_fn = utils.get_function(length_fn, ["texar.custom"])
     return length_fn
Example #7
def get_activation_fn(fn_name="identity", kwargs=None):
    """Returns an activation function `fn` with the signature
    `output = fn(input)`.

    If the function specified by :attr:`fn_name` has more than one argument
    without a default value, then all these arguments except the input feature
    argument must be specified in :attr:`kwargs`. Arguments with default values
    can also be specified in :attr:`kwargs` to take values other than the
    defaults.

    Args:
        fn_name (str or callable): The name or full path to an activation
            function, or the function itself.

            The function can be:

            - Built-in function defined in :mod:`tf` or \
              :mod:`tf.nn`, e.g., :tf_main:`identity <identity>`.
            - User-defined activation functions in `texar.custom`.
            - External activation functions. Must provide the full path, \
              e.g., "my_module.my_activation_fn".

            If a callable is provided, then it is returned directly.

        kwargs (optional): A `dict` or instance of :class:`~texar.HParams`
            containing the keyword arguments of the activation function.

    Returns:
        The activation function. `None` if :attr:`fn_name` is `None`.
    """
    if fn_name is None:
        return None

    fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
    activation_fn_ = utils.get_function(fn_name, fn_modules)
    activation_fn = activation_fn_

    # Make a partial function if necessary
    if kwargs is not None:
        if isinstance(kwargs, HParams):
            kwargs = kwargs.todict()

        def _partial_fn(features):
            return activation_fn_(features, **kwargs)

        activation_fn = _partial_fn

    return activation_fn
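Usage sketch (the `alpha` value is illustrative):

act_fn = get_activation_fn("leaky_relu", kwargs={"alpha": 0.1})
# y = act_fn(x)   # equivalent to tf.nn.leaky_relu(x, alpha=0.1)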
Example #8
    def _make_other_transformations(other_trans_hparams, data_spec):
        """Creates a list of tranformation functions based on the
        hyperparameters.

        Args:
            other_trans_hparams (list): A list of transformation functions,
                names, or full paths.
            data_spec: An instance of :class:`texar.data._DataSpec` to
                be passed to transformation functions.

        Returns:
            A list of transformation functions.
        """
        other_trans = []
        for tran in other_trans_hparams:
            if not is_callable(tran):
                tran = utils.get_function(tran, ["texar.custom"])
            other_trans.append(dsutils.make_partial(tran, data_spec))
        return other_trans
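The `other_trans_hparams` value may mix callables with names or full module paths; a hypothetical fragment (both the callable and the path are made-up placeholders):

other_trans_hparams = [my_custom_trans,            # some callable defined elsewhere
                       "my_module.my_trans_fn"]    # resolved from its full path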
Example #9
def get_initializer(hparams: Optional[HParams] = None) \
        -> Optional[Callable[[torch.Tensor], torch.Tensor]]:
    r"""Returns an initializer instance.

    .. role:: python(code)
       :language: python

    Args:
        hparams (dict or HParams, optional): Hyperparameters with the structure

            .. code-block:: python

                {
                    "type": "initializer_class_or_function",
                    "kwargs": {
                        #...
                    }
                }

            The "type" field can be a function name or module path. If name is
            provided, it be must be from one the following modules:
            :torch_docs:`torch.nn.init <nn.html\#torch-nn-init>` and
            :mod:`texar.custom`.

            Besides, the "type" field can also be an initialization function
            called with :python:`initialization_fn(**kwargs)`. In this case
            "type" can be the function, or its name or module path. If no
            keyword argument is required, "kwargs" can be omitted.

    Returns:
        An initializer instance. `None` if :attr:`hparams` is `None`.
    """
    if hparams is None:
        return None

    kwargs = hparams.get('kwargs', {})
    if isinstance(kwargs, HParams):
        kwargs = kwargs.todict()
    modules = ['torch.nn.init', 'torch', 'texar.custom']
    initializer_fn = utils.get_function(hparams['type'], modules)
    initializer = functools.partial(initializer_fn, **kwargs)

    return initializer
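Usage sketch (the `gain` value is illustrative):

init_hparams = {
    "type": "xavier_uniform_",      # resolved from torch.nn.init
    "kwargs": {"gain": 1.0},
}
initializer = get_initializer(init_hparams)
# initializer(tensor) applies torch.nn.init.xavier_uniform_(tensor, gain=1.0) in place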
Example #10
def get_constraint_fn(fn_name="NonNeg"):
    """Returns a constraint function.

    .. role:: python(code)
       :language: python

    The function must follow the signature:
    :python:`w_ = constraint_fn(w)`.

    Args:
        fn_name (str or callable): The name or full path to a
            constraint function, or the function itself.

            The function can be:

            - Built-in constraint functions defined in modules \
            :tf_main:`tf.keras.constraints <keras/constraints>` \
            (e.g., :tf_main:`NonNeg <keras/constraints/NonNeg>`) \
            or :tf_main:`tf < >` or :tf_main:`tf.nn <nn>` \
            (e.g., activation functions).
            - User-defined functions in :mod:`texar.custom`.
            - Externally defined function. Must provide the full path, \
            e.g., `"my_module.my_constraint_fn"`.

            If a callable is provided, then it is returned directly.

    Returns:
        The constraint function. `None` if :attr:`fn_name` is `None`.
    """
    if fn_name is None:
        return None

    fn_modules = [
        'tensorflow.keras.constraints', 'tensorflow', 'tensorflow.nn',
        'texar.custom'
    ]
    constraint_fn = utils.get_function(fn_name, fn_modules)
    return constraint_fn
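For example, an activation function from `tf.nn` can serve as a simple non-negativity projection:

constraint_fn = get_constraint_fn("relu")   # resolves to tf.nn.relu
# w_ = constraint_fn(w)                     # zeroes out negative entries of w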
Example #11
    def __init__(self,
                 input_size: int,
                 encoder_output_size: int,
                 vocab_size: int,
                 cell: Optional[RNNCellBase] = None,
                 output_layer: Optional[Union[nn.Module, torch.Tensor]] = None,
                 cell_input_fn: Optional[Callable[[torch.Tensor, torch.Tensor],
                                                  torch.Tensor]] = None,
                 hparams: Optional[HParams] = None):

        super().__init__(cell=cell,
                         input_size=input_size,
                         vocab_size=vocab_size,
                         output_layer=output_layer,
                         hparams=hparams)

        attn_hparams = self._hparams['attention']
        attn_kwargs = attn_hparams['kwargs'].todict()

        # Parse the `probability_fn` argument
        if 'probability_fn' in attn_kwargs:
            prob_fn = attn_kwargs['probability_fn']
            if prob_fn is not None and not callable(prob_fn):
                prob_fn = get_function(prob_fn, ['torch.nn.functional',
                                                 'texar.core'])
            attn_kwargs['probability_fn'] = prob_fn

        # Parse `encoder_output_size` and `decoder_output_size` arguments
        if attn_hparams['type'] in ['BahdanauAttention',
                                    'BahdanauMonotonicAttention']:
            attn_kwargs.update({"decoder_output_size": self._cell.hidden_size})
        attn_kwargs.update({"encoder_output_size": encoder_output_size})

        attn_modules = ['texar.core']
        self.attention_mechanism: AttentionMechanism
        self.attention_mechanism = check_or_get_instance(
            attn_hparams["type"], attn_kwargs, attn_modules,
            classtype=AttentionMechanism)

        self._attn_cell_kwargs = {
            "attention_layer_size": attn_hparams["attention_layer_size"],
            "alignment_history": attn_hparams["alignment_history"],
            "output_attention": attn_hparams["output_attention"],
        }
        self._cell_input_fn = cell_input_fn

        if attn_hparams["output_attention"] and vocab_size is not None and \
                self.attention_mechanism is not None:
            if attn_hparams["attention_layer_size"] is None:
                self._output_layer = nn.Linear(
                    encoder_output_size,
                    vocab_size)
            else:
                self._output_layer = nn.Linear(
                    sum(attn_hparams["attention_layer_size"])
                    if isinstance(attn_hparams["attention_layer_size"], list)
                    else attn_hparams["attention_layer_size"],
                    vocab_size)

        attn_cell = AttentionWrapper(
            self._cell,
            self.attention_mechanism,
            cell_input_fn=self._cell_input_fn,
            **self._attn_cell_kwargs)

        self._cell: AttentionWrapper = attn_cell

        self.memory: Optional[torch.Tensor] = None
        self.memory_sequence_length: Optional[torch.LongTensor] = None
Example #12
    def _build(self,
               distribution='MultivariateNormalDiag',
               distribution_kwargs=None,
               transform=True,
               num_samples=None):
        """Samples from a distribution and optionally performs transformation
        with an MLP layer.

        The inputs and outputs are the same as
        :class:`~texar.modules.ReparameterizedStochasticConnector` except that
        the distribution does not need to be reparameterizable, and gradients
        cannot be back-propagated through the samples.

        Args:
            distribution: An instance of a subclass of
                :tf_main:`TF Distribution <distributions/Distribution>`,
                or :tf_hmpg:`tensorflow_probability Distribution <probability>`.
                Can be a class, its name or module path, or a class instance.
            distribution_kwargs (dict, optional): Keyword arguments for the
                distribution constructor. Ignored if `distribution` is a
                class instance.
            transform (bool): Whether to perform MLP transformation of the
                distribution samples. If `False`, the structure/shape of a
                sample must match :attr:`output_size`.
            num_samples (optional): An `int` or `int` Tensor. Number of samples
                to generate. If not given, generate a single sample. Note
                that if batch size has already been included in
                `distribution`'s dimensionality, `num_samples` should be
                left as `None`.

        Returns:
            A tuple (output, sample), where

            - output: A Tensor or a (nested) tuple of Tensors with the same \
            structure and size of :attr:`output_size`. The batch dimension \
            equals :attr:`num_samples` if specified, or is determined by the \
            distribution dimensionality. If :attr:`transform` is `False`, \
            :attr:`output` will be equal to :attr:`sample`.
            - sample: The sample from the distribution, prior to transformation.

        Raises:
            ValueError: The output does not match :attr:`output_size`.
        """
        dstr = check_or_get_instance(distribution, distribution_kwargs, [
            "tensorflow.distributions", "tensorflow_probability.distributions",
            "tensorflow.contrib.distributions", "texar.custom"
        ])

        if num_samples:
            sample = dstr.sample(num_samples)
        else:
            sample = dstr.sample()

        if dstr.event_shape == []:
            sample = tf.reshape(sample,
                                sample.shape.concatenate(tf.TensorShape(1)))

        # Disable gradients through samples
        sample = tf.stop_gradient(sample)

        sample = tf.cast(sample, tf.float32)

        if transform:
            fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
            activation_fn = get_function(self.hparams.activation_fn,
                                         fn_modules)
            output = _mlp_transform(sample, self._output_size, activation_fn)
        else:
            output = sample

        _assert_same_size(output, self._output_size)
        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return output, sample
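A hypothetical call, assuming `connector` is an already-built connector instance and `mu`, `sigma` are tensors of matching shape:

output, sample = connector(
    distribution="MultivariateNormalDiag",
    distribution_kwargs={"loc": mu, "scale_diag": sigma})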
Example #13
def get_learning_rate_decay_fn(hparams=None):
    """Creates learning rate decay function based on the hyperparameters.

    See the :attr:`learning_rate_decay` field in
    :meth:`~texar.core.default_optimization_hparams` for all
    hyperparameters and default values.

    Args:
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically.

    Returns:
        function or `None`: If :attr:`hparams["type"]` is specified, returns a
        function that takes `(learning_rate, global_step)` and returns a
        decayed learning rate. If :attr:`hparams["type"]` is empty, returns
        `None`.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(
            hparams,
            default_optimization_hparams()["learning_rate_decay"])

    fn_type = hparams["type"]
    if fn_type is None or fn_type == "":
        return None

    fn_modules = ["tensorflow.train", "texar.custom"]
    decay_fn = utils.get_function(fn_type, fn_modules)
    fn_kwargs = hparams["kwargs"]
    if isinstance(fn_kwargs, HParams):
        fn_kwargs = fn_kwargs.todict()

    start_step = tf.cast(hparams["start_decay_step"], tf.int32)
    end_step = tf.cast(hparams["end_decay_step"], tf.int32)

    def lr_decay_fn(learning_rate, global_step):
        """Learning rate decay function.

        Args:
            learning_rate (float or Tensor): The original learning rate.
            global_step (int or scalar int Tensor): optimization step counter.

        Returns:
            scalar float Tensor: decayed learning rate.
        """
        offset_global_step = tf.maximum(
            tf.minimum(tf.cast(global_step, tf.int32), end_step) - start_step,
            0)
        if decay_fn == tf.train.piecewise_constant:
            decayed_lr = decay_fn(x=offset_global_step, **fn_kwargs)
        else:
            fn_kwargs_ = {
                "learning_rate": learning_rate,
                "global_step": offset_global_step
            }
            fn_kwargs_.update(fn_kwargs)
            decayed_lr = utils.call_function_with_redundant_kwargs(
                decay_fn, fn_kwargs_)

            decayed_lr = tf.maximum(decayed_lr, hparams["min_learning_rate"])

        return decayed_lr

    return lr_decay_fn
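Usage sketch (all values are illustrative; the field names follow the code above):

decay_hparams = {
    "type": "exponential_decay",      # resolved from tf.train
    "kwargs": {"decay_steps": 10000, "decay_rate": 0.5, "staircase": True},
    "start_decay_step": 0,
    "end_decay_step": 100000,
    "min_learning_rate": 1e-5,
}
lr_decay_fn = get_learning_rate_decay_fn(decay_hparams)
# decayed_lr = lr_decay_fn(learning_rate=1e-3, global_step=global_step)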
Example #14
    def _build(self,
               distribution='MultivariateNormalDiag',
               distribution_kwargs=None,
               transform=True,
               num_samples=None):
        """Samples from a distribution and optionally performs transformation
        with an MLP layer.

        The distribution must be reparameterizable, i.e.,
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED`.

        Args:
            distribution: A
                :tf_main:`TF Distribution <contrib/distributions/Distribution>`.
                Can be a class, its name or module path, or an instance of
                a subclass.
            distribution_kwargs (dict, optional): Keyword arguments for the
                distribution constructor. Ignored if `distribution` is a
                class instance.
            transform (bool): Whether to perform MLP transformation of the
                distribution samples. If `False`, the structure/shape of a
                sample must match :attr:`output_size`.
            num_samples (optional): An `int` or `int` Tensor. Number of samples
                to generate. If not given, generate a single sample. Note
                that if batch size has already been included in
                `distribution`'s dimensionality, `num_samples` should be
                left as `None`.

        Returns:
            A tuple (output, sample), where

            - output: A Tensor or a (nested) tuple of Tensors with the same \
            structure and size of :attr:`output_size`. The batch dimension \
            equals :attr:`num_samples` if specified, or is determined by the \
            distribution dimensionality.
            - sample: The sample from the distribution, prior to transformation.

        Raises:
            ValueError: If the distribution cannot be reparameterized.
            ValueError: The output does not match :attr:`output_size`.
        """
        dstr = check_or_get_instance(
            distribution, distribution_kwargs,
            ["tensorflow.contrib.distributions", "texar.custom"])

        if dstr.reparameterization_type == tf_dstr.NOT_REPARAMETERIZED:
            raise ValueError(
                "Distribution is not reparameterized: %s" % dstr.name)

        if num_samples:
            sample = dstr.sample(num_samples)
        else:
            sample = dstr.sample()

        #if dstr.event_shape == []:
        #    sample = tf.reshape(
        #        sample,
        #        sample.shape.concatenate(tf.TensorShape(1)))

        # sample = tf.cast(sample, tf.float32)
        if transform:
            fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
            activation_fn = get_function(self.hparams.activation_fn, fn_modules)
            output = _mlp_transform(sample, self._output_size, activation_fn)
        else:
            output = sample

        _assert_same_size(output, self._output_size)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return output, sample
Example #15
    def _build(self,
               distribution=None,
               distribution_type='MultivariateNormalDiag',
               distribution_kwargs=None,
               transform=False,
               num_samples=None):
        """Samples from a distribution and optionally performs transformation.

        Gradients would not propagate through the random samples.

        Args:
            distribution (optional): An instance of
                :class:`~tensorflow.contrib.distributions.Distribution`. If
                `None` (default), distribution is constructed based on
                :attr:`distribution_type`
            distribution_type (str, optional): Name or path to the distribution
                class which inherits
                :class:`~tensorflow.contrib.distributions.Distribution`. Ignored
                if :attr:`distribution` is specified.
            distribution_kwargs (dict, optional): Keyword arguments of the
                distribution class specified in :attr:`distribution_type`.
            transform (bool): Whether to perform MLP transformation of the
                samples. If `False`, the shape of a sample must match the
                :attr:`output_size`.
            num_samples (int or scalar int Tensor, optional): Number of samples
                to generate. Must be `None` during the training stage.

        Returns:
            If `num_samples` is `None`, returns a Tensor of shape `[batch_size x
            output_size]`; otherwise returns a Tensor of shape `[num_samples x
            output_size]`. `num_samples` must be specified when not in the
            training stage.

        Raises:
            ValueError: The output does not match the :attr:`output_size`.
        """
        if distribution:
            dstr = distribution
        elif distribution_type and distribution_kwargs:
            dstr = get_instance(
                distribution_type, distribution_kwargs,
                ["texar.custom", "tensorflow.contrib.distributions"])

        if num_samples:
            output = dstr.sample(num_samples)
        else:
            output = dstr.sample()

        if dstr.event_shape == []:
            output = tf.reshape(output,
                                output.shape.concatenate(tf.TensorShape(1)))

        # Disable gradients through samples
        output = tf.stop_gradient(output)

        output = tf.cast(output, tf.float32)

        if transform:
            fn_modules = ['texar.custom', 'tensorflow', 'tensorflow.nn']
            activation_fn = get_function(self.hparams.activation_fn,
                                         fn_modules)
            output = _mlp_transform(output, self._output_size, activation_fn)
        _assert_same_size(output, self._output_size)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return output
Example #16
    def _build(self,
               distribution=None,
               distribution_type='MultivariateNormalDiag',
               distribution_kwargs=None,
               transform=True,
               num_samples=None):
        """Samples from a distribution and optionally performs transformation.

        The distribution must be reparameterizable, i.e.,
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED`.

        Args:
            distribution (optional): An instance of
                :class:`~tensorflow.contrib.distributions.Distribution`. If
                `None` (default), distribution is constructed based on
                :attr:`distribution_type` or
                :attr:`hparams['distribution']['type']`.
            distribution_type (str, optional): Name or path to the distribution
                class which inherits
                :class:`~tensorflow.contrib.distributions.Distribution`. Ignored
                if :attr:`distribution` is specified.
            distribution_kwargs (dict, optional): Keyword arguments of the
                distribution class specified in :attr:`distribution_type`.
            transform (bool): Whether to perform MLP transformation of the
                samples. If `False`, the shape of a sample must match the
                :attr:`output_size`.
            num_samples (int or scalar int Tensor, optional): Number of samples
                to generate. Must be `None` during the training stage.

        Returns:
            output: If `num_samples` is `None`, returns a Tensor of shape
                `[batch_size x output_size]`; otherwise returns a Tensor of
                shape `[num_samples x output_size]`. `num_samples` must be
                specified when not in the training stage.
            latent_z: The sampled latent code z.

        Raises:
            ValueError: If the distribution cannot be reparameterized.
            ValueError: The output does not match the :attr:`output_size`.
        """
        if distribution:
            dstr = distribution
        elif distribution_type and distribution_kwargs:
            dstr = get_instance(
                distribution_type, distribution_kwargs,
                ["texar.custom", "tensorflow.contrib.distributions"])

        if dstr.reparameterization_type == tf_dstr.NOT_REPARAMETERIZED:
            raise ValueError("Distribution is not reparameterized: %s" %
                             dstr.name)

        if num_samples:
            latent_z = dstr.sample(num_samples)
        else:
            latent_z = dstr.sample()

        #if dstr.event_shape == []:
        #    latent_z = tf.reshape(
        #        latent_z,
        #        latent_z.shape.concatenate(tf.TensorShape(1)))

        # latent_z = tf.cast(latent_z, tf.float32)
        if transform:
            fn_modules = ['texar.custom', 'tensorflow', 'tensorflow.nn']
            activation_fn = get_function(self.hparams.activation_fn,
                                         fn_modules)
            output = _mlp_transform(latent_z, self._output_size, activation_fn)
        else:
            output = latent_z
        _assert_same_size(output, self._output_size)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return output, latent_z