def __init__(self, encoder_major=None, encoder_minor=None, hparams=None):
        EncoderBase.__init__(self, hparams)

        encoder_major_hparams = utils.get_instance_kwargs(
            None, self._hparams.encoder_major_hparams)
        encoder_minor_hparams = utils.get_instance_kwargs(
            None, self._hparams.encoder_minor_hparams)

        if encoder_major is not None:
            self._encoder_major = encoder_major
        else:
            with tf.variable_scope(self.variable_scope.name):
                with tf.variable_scope('encoder_major'):
                    self._encoder_major = utils.check_or_get_instance(
                        self._hparams.encoder_major_type,
                        encoder_major_hparams,
                        ['texar.modules.encoders', 'texar.custom'])

        if encoder_minor is not None:
            self._encoder_minor = encoder_minor
        elif self._hparams.config_share:
            with tf.variable_scope(self.variable_scope.name):
                with tf.variable_scope('encoder_minor'):
                    self._encoder_minor = utils.check_or_get_instance(
                        self._hparams.encoder_major_type,
                        encoder_major_hparams,
                        ['texar.modules.encoders', 'texar.custom'])
        else:
            with tf.variable_scope(self.variable_scope.name):
                with tf.variable_scope('encoder_minor'):
                    self._encoder_minor = utils.check_or_get_instance(
                        self._hparams.encoder_minor_type,
                        encoder_minor_hparams,
                        ['texar.modules.encoders', 'texar.custom'])
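
For context, `utils.check_or_get_instance` accepts a class instance, a class, or a class name/module path, plus constructor kwargs and a list of modules to search. A minimal standalone sketch of the lookup performed above, reusing the snippet's `utils` helper (the encoder type and hparams below are illustrative, not asserted defaults):

# Illustrative lookup mirroring the constructor above. Assumes Texar-TF is
# installed; "UnidirectionalRNNEncoder" is found under texar.modules.encoders,
# the first module path searched.
encoder_minor = utils.check_or_get_instance(
    "UnidirectionalRNNEncoder",                                  # class name
    {"hparams": {"rnn_cell": {"kwargs": {"num_units": 256}}}},   # ctor kwargs
    ['texar.modules.encoders', 'texar.custom'])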
Example #2
    def __init__(
            self,
            memory,
            memory_sequence_length=None,
            cell=None,
            cell_dropout_mode=None,
            vocab_size=None,
            output_layer=None,
            #attention_layer=None, # TODO(zhiting): only valid for tf>=1.0
            cell_input_fn=None,
            hparams=None):
        RNNDecoderBase.__init__(self, cell, vocab_size, output_layer,
                                cell_dropout_mode, hparams)

        attn_hparams = self._hparams['attention']
        attn_kwargs = attn_hparams['kwargs'].todict()

        # Parse the 'probability_fn' argument
        if 'probability_fn' in attn_kwargs:
            prob_fn = attn_kwargs['probability_fn']
            if prob_fn is not None and not callable(prob_fn):
                prob_fn = utils.get_function(prob_fn, [
                    'tensorflow.nn', 'tensorflow.contrib.sparsemax',
                    'tensorflow.contrib.seq2seq'
                ])
            attn_kwargs['probability_fn'] = prob_fn

        attn_kwargs.update({
            "memory_sequence_length": memory_sequence_length,
            "memory": memory
        })
        self._attn_kwargs = attn_kwargs
        attn_modules = ['tensorflow.contrib.seq2seq', 'texar.custom']
        # Use variable_scope to ensure all trainable variables created in
        # the attention mechanism are collected
        with tf.variable_scope(self.variable_scope):
            attention_mechanism = utils.check_or_get_instance(
                attn_hparams["type"],
                attn_kwargs,
                attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

        self._attn_cell_kwargs = {
            "attention_layer_size": attn_hparams["attention_layer_size"],
            "alignment_history": attn_hparams["alignment_history"],
            "output_attention": attn_hparams["output_attention"],
        }
        self._cell_input_fn = cell_input_fn
        # Use variable_scope to ensure all trainable variables created in
        # AttentionWrapper are collected
        with tf.variable_scope(self.variable_scope):
            #if attention_layer is not None:
            #    self._attn_cell_kwargs["attention_layer_size"] = None
            attn_cell = AttentionWrapper(
                self._cell,
                attention_mechanism,
                cell_input_fn=self._cell_input_fn,
                #attention_layer=attention_layer,
                **self._attn_cell_kwargs)
            self._cell = attn_cell
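
A minimal construction sketch for the decoder above (assuming this is Texar-TF's `AttentionRNNDecoder`; `enc_outputs` and `enc_lengths` are hypothetical encoder tensors, and the attention hparams mirror the keys read in the constructor):

import texar as tx

decoder = tx.modules.AttentionRNNDecoder(
    memory=enc_outputs,                  # hypothetical [batch, time, dim] Tensor
    memory_sequence_length=enc_lengths,  # hypothetical [batch] Tensor
    vocab_size=10000,
    hparams={
        "attention": {
            "type": "LuongAttention",    # resolved in tensorflow.contrib.seq2seq
            "kwargs": {"num_units": 256},
        }
    })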
Example #3
    def _build_decoder(self):
        kwargs = {
            "vocab_size": self._tgt_vocab.size,
            "hparams": self._hparams.decoder_hparams.todict()
        }
        self._decoder = utils.check_or_get_instance(
            self._hparams.decoder, kwargs, ["texar.modules", "texar.custom"])
Example #4
    def _get_beam_search_cell(self, beam_width):
        """Returns the RNN cell for beam search decoding.
        """
        with tf.variable_scope(self.variable_scope, reuse=True):
            attn_kwargs = copy.copy(self._attn_kwargs)

            memory = attn_kwargs['memory']
            attn_kwargs['memory'] = tile_batch(memory, multiplier=beam_width)

            memory_seq_length = attn_kwargs['memory_sequence_length']
            if memory_seq_length is not None:
                attn_kwargs['memory_sequence_length'] = tile_batch(
                    memory_seq_length, beam_width)

            attn_modules = ['tensorflow.contrib.seq2seq', 'texar.custom']
            bs_attention_mechanism = utils.check_or_get_instance(
                self._hparams.attention.type,
                attn_kwargs,
                attn_modules,
                classtype=tf.contrib.seq2seq.AttentionMechanism)

            bs_attn_cell = AttentionWrapper(self._cell._cell,
                                            bs_attention_mechanism,
                                            cell_input_fn=self._cell_input_fn,
                                            **self._attn_cell_kwargs)

            self._beam_search_cell = bs_attn_cell

            return bs_attn_cell
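
The tiling above follows the usual TF beam-search convention: each batch entry of the attention memory is repeated `beam_width` times so it lines up with the beam-expanded decoder state. A self-contained illustration of `tile_batch` (shapes are illustrative):

import tensorflow as tf
from tensorflow.contrib.seq2seq import tile_batch

memory = tf.zeros([4, 10, 256])            # [batch_size, max_time, dim]
tiled = tile_batch(memory, multiplier=5)   # -> [batch_size * 5, max_time, dim]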
Example #5
    def _build_encoder(self):
        kwargs = {
            "hparams": self._hparams.encoder_hparams.todict()
        }
        self._encoder = utils.check_or_get_instance(
            self._hparams.encoder, kwargs,
            ["texar.modules", "texar.custom"])
Example #6
    def _build_connector(self):
        kwargs = {
            "output_size": self._decoder.state_size,
            "hparams": self._hparams.connector_hparams.todict()
        }
        self._connector = utils.check_or_get_instance(
            self._hparams.connector, kwargs, ["texar.modules", "texar.custom"])
Example #7
    def __init__(self,
                 env_config,
                 sess=None,
                 policy=None,
                 policy_kwargs=None,
                 policy_caller_kwargs=None,
                 learning_rate=None,
                 hparams=None):
        EpisodicAgentBase.__init__(self, env_config, hparams)

        self._sess = sess
        self._lr = learning_rate
        self._discount_factor = self._hparams.discount_factor

        with tf.variable_scope(self.variable_scope):
            if policy is None:
                kwargs = utils.get_instance_kwargs(
                    policy_kwargs, self._hparams.policy_hparams)
                policy = utils.check_or_get_instance(
                    self._hparams.policy_type,
                    kwargs,
                    module_paths=['texar.modules', 'texar.custom'])
            self._policy = policy
            self._policy_caller_kwargs = policy_caller_kwargs or {}

        self._observs = []
        self._actions = []
        self._rewards = []

        self._train_outputs = None

        self._build_graph()
Example #8
    def __init__(self,
                 output_size: OutputSize,
                 mlp_input_size: Union[torch.Size, MaybeTuple[int], int],
                 distribution: Union[Distribution, str] = 'MultivariateNormal',
                 distribution_kwargs: Optional[Dict[str, Any]] = None,
                 hparams: Optional[HParams] = None):
        super().__init__(output_size, hparams)
        if distribution_kwargs is None:
            distribution_kwargs = {}
        self._dstr_kwargs = distribution_kwargs
        if isinstance(distribution, str):
            self._dstr: Distribution = utils.check_or_get_instance(
                distribution, self._dstr_kwargs,
                ["torch.distributions", "texar.custom"])
        else:
            self._dstr = distribution

        if self._dstr.has_rsample:
            raise ValueError("Distribution should not be reparameterizable")

        if isinstance(mlp_input_size, int):
            input_feature = mlp_input_size
        else:
            input_feature = np.prod(mlp_input_size)
        self._linear_layer = nn.Linear(
            input_feature, _sum_output_size(output_size))

        self._activation_fn = get_activation_fn(
            self.hparams.activation_fn)
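
The distribution lookup above resolves a string name against `torch.distributions` (then `texar.custom`). A minimal sketch reusing the snippet's `utils` helper, with a non-reparameterizable distribution so the `has_rsample` check passes (the `Bernoulli` arguments are illustrative):

import torch

dstr = utils.check_or_get_instance(
    "Bernoulli",                              # torch.distributions.Bernoulli
    {"probs": torch.tensor([0.3, 0.7])},
    ["torch.distributions", "texar.custom"])
assert not dstr.has_rsample                   # passes the constructor's check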
Example #9
def get_initializer(hparams=None):
    """Returns an initializer instance.

    .. role:: python(code)
       :language: python

    Args:
        hparams (dict or HParams, optional): Hyperparameters with the structure

            .. code-block:: python

                {
                    "type": "initializer_class_or_function",
                    "kwargs": {
                        #...
                    }
                }

            The "type" field can be a initializer class, its name or module
            path, or class instance. If class name is provided, the class must
            be from one the following modules:
            :tf_main:`tf.initializers <initializers>`,
            :tf_main:`tf.keras.initializers <keras/initializers>`,
            :tf_main:`tf < >`, and :mod:`texar.custom`. The class is created
            by :python:`initializer_class(**kwargs)`. If a class instance
            is given, "kwargs" is ignored and can be omitted.

            Besides, the "type" field can also be an initialization function
            called with :python:`initialization_fn(**kwargs)`. In this case
            "type" can be the function, or its name or module path. If
            function name is provided, the function must be from one of the
            above modules or module `tf.contrib.layers`. If no
            keyword argument is required, "kwargs" can be omitted.

    Returns:
        An initializer instance. `None` if :attr:`hparams` is `None`.
    """
    if hparams is None:
        return None

    kwargs = hparams.get("kwargs", {})
    if isinstance(kwargs, HParams):
        kwargs = kwargs.todict()
    modules = [
        "tensorflow.initializers", "tensorflow.keras.initializers",
        "tensorflow", "texar.custom"
    ]
    try:
        initializer = utils.check_or_get_instance(hparams["type"], kwargs,
                                                  modules)
    except TypeError:
        modules += ['tensorflow.contrib.layers']
        initializer_fn = utils.get_function(hparams["type"], modules)
        initializer = initializer_fn(**kwargs)

    return initializer
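
A usage sketch for `get_initializer` (assumes TF 1.x, where `random_uniform_initializer` is a class reachable from the `tensorflow` module searched above; the bounds are illustrative):

init_hparams = {
    "type": "random_uniform_initializer",
    "kwargs": {"minval": -0.1, "maxval": 0.1}
}
initializer = get_initializer(init_hparams)   # a RandomUniform instance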
    def _build_network(self, network, kwargs):
        if network is not None:
            self._network = network
        else:
            kwargs = utils.get_instance_kwargs(kwargs,
                                               self._hparams.network_hparams)
            self._network = utils.check_or_get_instance(
                self._hparams.network_type,
                kwargs,
                module_paths=['texar.modules', 'texar.custom'])
Example #11
    def _build_embedders(self):
        kwargs = {
            "vocab_size": self._src_vocab.size,
            "hparams": self._hparams.source_embedder_hparams.todict()
        }
        self._src_embedder = utils.check_or_get_instance(
            self._hparams.source_embedder, kwargs,
            ["texar.modules", "texar.custom"])

        if self._hparams.embedder_share:
            self._tgt_embedder = self._src_embedder
        else:
            kwargs = {
                "vocab_size": self._tgt_vocab.size,
            }
            if self._hparams.embedder_hparams_share:
                kwargs["hparams"] = \
                        self._hparams.source_embedder_hparams.todict()
            else:
                kwargs["hparams"] = \
                        self._hparams.target_embedder_hparams.todict()
            self._tgt_embedder = utils.check_or_get_instance(
                self._hparams.target_embedder, kwargs,
                ["texar.modules", "texar.custom"])
def get_regularizer(hparams=None):
    r"""Returns a variable regularizer instance.

    See :func:`~texar.core.default_regularizer_hparams` for all
    hyperparameters and default values.

    The "type" field can be a subclass
    of :class:`~texar.core.regularizers.Regularizer`, its string name
    or module path, or a class instance.

    Args:
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters are set to default values.

    Returns:
        A :class:`~texar.core.regularizers.Regularizer` instance.
        `None` if :attr:`hparams` is `None` or taking the default
        hyperparameter value.

    Raises:
        ValueError: The resulting regularizer is not an instance of
            :class:`~texar.core.regularizers.Regularizer`.
    """

    if hparams is None:
        return None

    if isinstance(hparams, dict):
        hparams = HParams(hparams, default_regularizer_hparams())

    rgl = utils.check_or_get_instance(
        hparams.type, hparams.kwargs.todict(),
        ["texar.core.regularizers", "texar.custom"])

    if not isinstance(rgl, Regularizer):
        raise ValueError("The regularizer must be an instance of "
                         "texar.core.regularizers.Regularizer.")

    if isinstance(rgl, L1L2) and rgl.l1 == 0. and rgl.l2 == 0.:
        return None

    return rgl
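
A usage sketch (the `L1L2` regularizer referenced above lives in `texar.core.regularizers`; the weights are illustrative):

regularizer = get_regularizer({
    "type": "L1L2",
    "kwargs": {"l1": 0., "l2": 1e-4}
})
# With both weights left at 0., the function returns None instead.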
def get_rnn_cell(input_size, hparams=None):
    r"""Creates an RNN cell.

    See :func:`~texar.core.default_rnn_cell_hparams` for all
    hyperparameters and default values.

    Args:
        input_size (int): Size of the input to the cell in the first layer.
        hparams (dict or HParams, optional): Cell hyperparameters. Missing
            hyperparameters are set to default values.

    Returns:
        A cell instance.

    Raises:
        ValueError: If ``hparams["num_layers"]``>1 and ``hparams["type"]`` is a
            class instance.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_rnn_cell_hparams())

    d_hp = hparams['dropout']
    variational_recurrent = d_hp['variational_recurrent']
    input_keep_prob = d_hp['input_keep_prob']
    output_keep_prob = d_hp['output_keep_prob']
    state_keep_prob = d_hp['state_keep_prob']

    cells = []
    num_layers = hparams['num_layers']
    cell_kwargs = hparams['kwargs'].todict()
    # rename 'num_units' to 'hidden_size' following PyTorch conventions
    cell_kwargs['hidden_size'] = cell_kwargs['num_units']
    del cell_kwargs['num_units']

    for layer_i in range(num_layers):
        # Create the basic cell
        cell_type = hparams["type"]
        if layer_i == 0:
            cell_kwargs['input_size'] = input_size
        else:
            cell_kwargs['input_size'] = cell_kwargs['hidden_size']
        if not isinstance(cell_type, str) and not isinstance(cell_type, type):
            if num_layers > 1:
                raise ValueError(
                    "If 'num_layers'>1, then 'type' must be a cell class or "
                    "its name/module path, rather than a cell instance.")
        cell_modules = [
            'texar.core.cell_wrappers',  # prefer our wrappers
            'torch.nn.modules.rnn',
            'texar.custom'
        ]
        cell = utils.check_or_get_instance(cell_type, cell_kwargs,
                                           cell_modules)
        if isinstance(cell, nn.RNNCellBase):
            cell = wrappers.wrap_builtin_cell(cell)

        # Optionally add dropout
        if (input_keep_prob < 1.0 or output_keep_prob < 1.0
                or state_keep_prob < 1.0):
            # TODO: Would this result in non-final layer outputs being
            #       dropped twice?
            cell = wrappers.DropoutWrapper(
                cell=cell,
                input_keep_prob=input_keep_prob,
                output_keep_prob=output_keep_prob,
                state_keep_prob=state_keep_prob,
                variational_recurrent=variational_recurrent)

        # Optionally add residual and highway connections
        if layer_i > 0:
            if hparams['residual']:
                cell = wrappers.ResidualWrapper(cell)
            if hparams['highway']:
                cell = wrappers.HighwayWrapper(cell)

        cells.append(cell)

    if hparams['num_layers'] > 1:
        cell = wrappers.MultiRNNCell(cells)
    else:
        cell = cells[0]

    return cell
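
A usage sketch for this PyTorch version (the hparams keys mirror those read above; anything omitted falls back to `default_rnn_cell_hparams`):

cell = get_rnn_cell(
    input_size=100,
    hparams={
        "type": "LSTMCell",             # resolved in texar.core.cell_wrappers
        "kwargs": {"num_units": 256},   # renamed to hidden_size internally
        "num_layers": 2,
        "dropout": {"output_keep_prob": 0.5},
    })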
    def __init__(self,
                 env_config,
                 sess=None,
                 qnet=None,
                 target=None,
                 qnet_kwargs=None,
                 qnet_caller_kwargs=None,
                 replay_memory=None,
                 replay_memory_kwargs=None,
                 exploration=None,
                 exploration_kwargs=None,
                 hparams=None):
        EpisodicAgentBase.__init__(self, env_config, hparams)

        self._sess = sess
        self._cold_start_steps = self._hparams.cold_start_steps
        self._sample_batch_size = self._hparams.sample_batch_size
        self._update_period = self._hparams.update_period
        self._discount_factor = self._hparams.discount_factor
        self._target_update_strategy = self._hparams.target_update_strategy
        self._num_actions = self._env_config.action_space.high - \
                            self._env_config.action_space.low

        with tf.variable_scope(self.variable_scope):
            if qnet is None:
                kwargs = utils.get_instance_kwargs(qnet_kwargs,
                                                   self._hparams.qnet_hparams)
                qnet = utils.check_or_get_instance(
                    ins_or_class_or_name=self._hparams.qnet_type,
                    kwargs=kwargs,
                    module_paths=['texar.modules', 'texar.custom'])
                target = utils.check_or_get_instance(
                    ins_or_class_or_name=self._hparams.qnet_type,
                    kwargs=kwargs,
                    module_paths=['texar.modules', 'texar.custom'])
            self._qnet = qnet
            self._target = target
            self._qnet_caller_kwargs = qnet_caller_kwargs or {}

            if replay_memory is None:
                kwargs = utils.get_instance_kwargs(
                    replay_memory_kwargs, self._hparams.replay_memory_hparams)
                replay_memory = utils.check_or_get_instance(
                    ins_or_class_or_name=self._hparams.replay_memory_type,
                    kwargs=kwargs,
                    module_paths=['texar.core', 'texar.custom'])
            self._replay_memory = replay_memory

            if exploration is None:
                kwargs = utils.get_instance_kwargs(
                    exploration_kwargs, self._hparams.exploration_hparams)
                exploration = utils.check_or_get_instance(
                    ins_or_class_or_name=self._hparams.exploration_type,
                    kwargs=kwargs,
                    module_paths=['texar.core', 'texar.custom'])
            self._exploration = exploration

        self._build_graph()

        self._observ = None
        self._action = None
        self._timestep = 0
Example #15
    def __init__(self,
                 input_size: int,
                 encoder_output_size: int,
                 vocab_size: int,
                 cell: Optional[RNNCellBase] = None,
                 output_layer: Optional[Union[nn.Module, torch.Tensor]] = None,
                 cell_input_fn: Optional[Callable[[torch.Tensor, torch.Tensor],
                                                  torch.Tensor]] = None,
                 hparams: Optional[HParams] = None):

        super().__init__(cell=cell,
                         input_size=input_size,
                         vocab_size=vocab_size,
                         output_layer=output_layer,
                         hparams=hparams)

        attn_hparams = self._hparams['attention']
        attn_kwargs = attn_hparams['kwargs'].todict()

        # Parse the `probability_fn` argument
        if 'probability_fn' in attn_kwargs:
            prob_fn = attn_kwargs['probability_fn']
            if prob_fn is not None and not callable(prob_fn):
                prob_fn = get_function(prob_fn, ['torch.nn.functional',
                                                 'texar.core'])
            attn_kwargs['probability_fn'] = prob_fn

        # Parse `encoder_output_size` and `decoder_output_size` arguments
        if attn_hparams['type'] in ['BahdanauAttention',
                                    'BahdanauMonotonicAttention']:
            attn_kwargs.update({"decoder_output_size": self._cell.hidden_size})
        attn_kwargs.update({"encoder_output_size": encoder_output_size})

        attn_modules = ['texar.core']
        self.attention_mechanism: AttentionMechanism
        self.attention_mechanism = check_or_get_instance(
            attn_hparams["type"], attn_kwargs, attn_modules,
            classtype=AttentionMechanism)

        self._attn_cell_kwargs = {
            "attention_layer_size": attn_hparams["attention_layer_size"],
            "alignment_history": attn_hparams["alignment_history"],
            "output_attention": attn_hparams["output_attention"],
        }
        self._cell_input_fn = cell_input_fn

        if attn_hparams["output_attention"] and vocab_size is not None and \
                self.attention_mechanism is not None:
            if attn_hparams["attention_layer_size"] is None:
                self._output_layer = nn.Linear(
                    encoder_output_size,
                    vocab_size)
            else:
                self._output_layer = nn.Linear(
                    sum(attn_hparams["attention_layer_size"])
                    if isinstance(attn_hparams["attention_layer_size"], list)
                    else attn_hparams["attention_layer_size"],
                    vocab_size)

        attn_cell = AttentionWrapper(
            self._cell,
            self.attention_mechanism,
            cell_input_fn=self._cell_input_fn,
            **self._attn_cell_kwargs)

        self._cell: AttentionWrapper = attn_cell

        self.memory: Optional[torch.Tensor] = None
        self.memory_sequence_length: Optional[torch.LongTensor] = None
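
A minimal construction sketch (assuming this is Texar-PyTorch's `AttentionRNNDecoder`; `LuongAttention` is provided by `texar.core`, and the sizes are illustrative):

decoder = AttentionRNNDecoder(
    input_size=256,
    encoder_output_size=256,
    vocab_size=10000,
    hparams={
        "attention": {
            "type": "LuongAttention",     # resolved in texar.core
            "kwargs": {"num_units": 256},
        }
    })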
Example #16
def get_rnn_cell(hparams=None, mode=None):
    """Creates an RNN cell.

    See :func:`~texar.core.default_rnn_cell_hparams` for all
    hyperparameters and default values.

    Args:
        hparams (dict or HParams, optional): Cell hyperparameters. Missing
            hyperparameters are set to default values.
        mode (optional): A Tensor taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
            `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout will be
            controlled by :func:`texar.global_mode`.

    Returns:
        A cell instance.

    Raises:
        ValueError: If hparams["num_layers"]>1 and hparams["type"] is a class
            instance.
        ValueError: The cell is not an
            :tf_main:`RNNCell <contrib/rnn/RNNCell>` instance.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_rnn_cell_hparams())

    d_hp = hparams["dropout"]
    if d_hp["variational_recurrent"] and \
            len(d_hp["input_size"]) != hparams["num_layers"]:
        raise ValueError(
            "If variational_recurrent=True, input_size must be a list of "
            "num_layers(%d) integers. Got len(input_size)=%d." %
            (hparams["num_layers"], len(d_hp["input_size"])))

    cells = []
    cell_kwargs = hparams["kwargs"].todict()
    num_layers = hparams["num_layers"]
    for layer_i in range(num_layers):
        # Create the basic cell
        cell_type = hparams["type"]
        if not is_str(cell_type) and not isinstance(cell_type, type):
            if num_layers > 1:
                raise ValueError(
                    "If 'num_layers'>1, then 'type' must be a cell class or "
                    "its name/module path, rather than a cell instance.")
        cell_modules = ['tensorflow.contrib.rnn', 'texar.custom']
        cell = utils.check_or_get_instance(cell_type, cell_kwargs,
                                           cell_modules, rnn.RNNCell)

        # Optionally add dropout
        if d_hp["input_keep_prob"] < 1.0 or \
                d_hp["output_keep_prob"] < 1.0 or \
                d_hp["state_keep_prob"] < 1.0:
            vr_kwargs = {}
            if d_hp["variational_recurrent"]:
                vr_kwargs = {
                    "variational_recurrent": True,
                    "input_size": d_hp["input_size"][layer_i],
                    "dtype": tf.float32
                }
            input_keep_prob = switch_dropout(d_hp["input_keep_prob"], mode)
            output_keep_prob = switch_dropout(d_hp["output_keep_prob"], mode)
            state_keep_prob = switch_dropout(d_hp["state_keep_prob"], mode)
            cell = rnn.DropoutWrapper(cell=cell,
                                      input_keep_prob=input_keep_prob,
                                      output_keep_prob=output_keep_prob,
                                      state_keep_prob=state_keep_prob,
                                      **vr_kwargs)

        # Optionally add residual and highway connections
        if layer_i > 0:
            if hparams["residual"]:
                cell = rnn.ResidualWrapper(cell)
            if hparams["highway"]:
                cell = rnn.HighwayWrapper(cell)

        cells.append(cell)

    if hparams["num_layers"] > 1:
        cell = rnn.MultiRNNCell(cells)
    else:
        cell = cells[0]

    return cell
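
A usage sketch for the TF version (assumes Texar-TF defaults; "LSTMCell" is resolved from `tensorflow.contrib.rnn` as in the loop above):

cell = get_rnn_cell(hparams={
    "type": "LSTMCell",
    "kwargs": {"num_units": 256},
    "num_layers": 2,
    "dropout": {"output_keep_prob": 0.5},
})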
Example #17
    def _build(self,
               distribution='MultivariateNormalDiag',
               distribution_kwargs=None,
               transform=True,
               num_samples=None):
        """Samples from a distribution and optionally performs transformation
        with an MLP layer.

        The inputs and outputs are the same as
        :class:`~texar.modules.ReparameterizedStochasticConnector` except that
        the distribution does not need to be reparameterizable, and gradients
        cannot be back-propagated through the samples.

        Args:
            distribution: An instance of a subclass of
                :tf_main:`TF Distribution <distributions/Distribution>`,
                or :tf_hmpg:`tensorflow_probability Distribution <probability>`.
                Can be a class, its name or module path, or a class instance.
            distribution_kwargs (dict, optional): Keyword arguments for the
                distribution constructor. Ignored if `distribution` is a
                class instance.
            transform (bool): Whether to perform MLP transformation of the
                distribution samples. If `False`, the structure/shape of a
                sample must match :attr:`output_size`.
            num_samples (optional): An `int` or `int` Tensor. Number of samples
                to generate. If not given, generate a single sample. Note
                that if batch size has already been included in
                `distribution`'s dimensionality, `num_samples` should be
                left as `None`.

        Returns:
            A tuple (output, sample), where

            - output: A Tensor or a (nested) tuple of Tensors with the same \
            structure and size of :attr:`output_size`. The batch dimension \
            equals :attr:`num_samples` if specified, or is determined by the \
            distribution dimensionality. If :attr:`transform` is `False`, \
            :attr:`output` will be equal to :attr:`sample`.
            - sample: The sample from the distribution, prior to transformation.

        Raises:
            ValueError: The output does not match :attr:`output_size`.
        """
        dstr = check_or_get_instance(distribution, distribution_kwargs, [
            "tensorflow.distributions", "tensorflow_probability.distributions",
            "tensorflow.contrib.distributions", "texar.custom"
        ])

        if num_samples:
            sample = dstr.sample(num_samples)
        else:
            sample = dstr.sample()

        if dstr.event_shape == []:
            sample = tf.reshape(sample,
                                sample.shape.concatenate(tf.TensorShape(1)))

        # Disable gradients through samples
        sample = tf.stop_gradient(sample)

        sample = tf.cast(sample, tf.float32)

        if transform:
            fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
            activation_fn = get_function(self.hparams.activation_fn,
                                         fn_modules)
            output = _mlp_transform(sample, self._output_size, activation_fn)
        else:
            output = sample

        _assert_same_size(output, self._output_size)
        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return output, sample
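
A usage sketch (assuming a Texar-TF stochastic connector instance `connector` whose `_build` is shown above, invoked through the module's call interface; `mu` and `sigma` are hypothetical `[batch, dim]` tensors):

output, sample = connector(
    distribution="MultivariateNormalDiag",          # the default shown above
    distribution_kwargs={"loc": mu, "scale_diag": sigma})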
Example #18
    def _build(self,
               distribution='MultivariateNormalDiag',
               distribution_kwargs=None,
               transform=True,
               num_samples=None):
        """Samples from a distribution and optionally performs transformation
        with an MLP layer.

        The distribution must be reparameterizable, i.e.,
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED`.

        Args:
            distribution: A
                :tf_main:`TF Distribution <contrib/distributions/Distribution>`.
                Can be a class, its name or module path, or an instance of
                a subclass.
            distribution_kwargs (dict, optional): Keyword arguments for the
                distribution constructor. Ignored if `distribution` is a
                class instance.
            transform (bool): Whether to perform MLP transformation of the
                distribution samples. If `False`, the structure/shape of a
                sample must match :attr:`output_size`.
            num_samples (optional): An `int` or `int` Tensor. Number of samples
                to generate. If not given, generate a single sample. Note
                that if batch size has already been included in
                `distribution`'s dimensionality, `num_samples` should be
                left as `None`.

        Returns:
            A tuple (output, sample), where

            - output: A Tensor or a (nested) tuple of Tensors with the same \
            structure and size of :attr:`output_size`. The batch dimension \
            equals :attr:`num_samples` if specified, or is determined by the \
            distribution dimensionality.
            - sample: The sample from the distribution, prior to transformation.

        Raises:
            ValueError: If distribution cannot be reparametrized.
            ValueError: The output does not match :attr:`output_size`.
        """
        dstr = check_or_get_instance(
            distribution, distribution_kwargs,
            ["tensorflow.contrib.distributions", "texar.custom"])

        if dstr.reparameterization_type == tf_dstr.NOT_REPARAMETERIZED:
            raise ValueError(
                "Distribution is not reparameterized: %s" % dstr.name)

        if num_samples:
            sample = dstr.sample(num_samples)
        else:
            sample = dstr.sample()

        #if dstr.event_shape == []:
        #    sample = tf.reshape(
        #        sample,
        #        sample.shape.concatenate(tf.TensorShape(1)))

        # sample = tf.cast(sample, tf.float32)
        if transform:
            fn_modules = ['tensorflow', 'tensorflow.nn', 'texar.custom']
            activation_fn = get_function(self.hparams.activation_fn, fn_modules)
            output = _mlp_transform(sample, self._output_size, activation_fn)
        else:
            output = sample

        _assert_same_size(output, self._output_size)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return output, sample
Example #19
    def forward(self,    # type: ignore
                num_samples: Optional[Union[int, torch.Tensor]] = None,
                transform: bool = True) -> Tuple[Any, Any]:
        r"""Samples from a distribution and optionally performs transformation
        with an MLP layer.
        The distribution must be reparameterizable, i.e.,
        :python:`Distribution.has_rsample == True`.

        Args:
            num_samples (optional): An ``int`` or ``int`` ``Tensor``.
                Number of samples to generate. If not given,
                generate a single sample. Note that if batch size has
                already been included in :attr:`distribution`'s dimensionality,
                :attr:`num_samples` should be left as ``None``.
            transform (bool): Whether to perform MLP transformation of the
                distribution samples. If ``False``, the structure/shape of a
                sample must match :attr:`output_size`.


        :returns:
            A tuple (:attr:`output`, :attr:`sample`), where

            - output: A ``Tensor`` or a (nested) ``tuple`` of Tensors with
              the same structure and size of :attr:`output_size`.
              The batch dimension equals :attr:`num_samples` if specified,
              or is determined by the distribution dimensionality.
              If :attr:`transform` is `False`, it will be
              equal to :attr:`sample`.
            - sample: The sample from the distribution, prior to transformation.

        Raises:
            ValueError: If distribution is not reparameterizable.
            ValueError: The output does not match :attr:`output_size`.
        """
        if isinstance(self._dstr_type, str):
            dstr: Distribution = utils.check_or_get_instance(
                self._dstr_type, self._dstr_kwargs,
                ["torch.distributions", "texar.custom"])
        else:
            dstr = self._dstr_type

        if not dstr.has_rsample:
            raise ValueError("Distribution should be reparameterizable")

        if num_samples:
            sample = dstr.rsample([num_samples])
        else:
            sample = dstr.rsample()

        if transform:
            output = _mlp_transform(
                sample,
                self._output_size,
                self._linear_layer,
                self._activation_fn)
            _assert_same_size(output, self._output_size)

        else:
            output = sample

        return output, sample
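
A usage sketch for the reparameterized PyTorch connector (hypothetical: assumes `connector` was constructed with a reparameterizable distribution such as `"Normal"`, so `has_rsample` is `True`):

# Draw 32 samples and transform them to match output_size.
output, sample = connector(num_samples=32)
# Gradients flow back through `sample` because rsample() is used above.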