Example 1
0
def get_function(fn_or_name, module_paths=None):
    """Returns the function of specified name and module.

    Args:
        fn_or_name (str or callable): Name or full path to a function, or the
            function itself. A callable is returned unchanged.
        module_paths (list, optional): A list of paths to candidate modules to
            search for the function. This is used only when the function
            cannot be located solely based on :attr:`fn_or_name`. The first
            module in the list that contains the function is used.

    Returns:
        A function.

    Raises:
        ValueError: If the function cannot be located from
            :attr:`fn_or_name` alone nor relative to any entry of
            :attr:`module_paths`.
    """
    if is_callable(fn_or_name):
        return fn_or_name

    fn = locate(fn_or_name)
    if fn is None and module_paths is not None:
        # Fall back to resolving the name relative to each candidate module,
        # in order; the first module that yields a match wins.
        for module_path in module_paths:
            fn = locate('.'.join([module_path, fn_or_name]))
            if fn is not None:
                break

    if fn is None:
        raise ValueError("Method not found in {}: {}".format(
            module_paths, fn_or_name))

    return fn
Example 2
0
def _make_output_layer(output_layer, vocab_size, output_layer_bias,
                       variable_scope):
    """Makes a decoder output layer.

    Args:
        output_layer: One of:

            - a callable, used directly as the output layer;
            - a tensor, from which the layer is built by
              `_make_output_layer_from_tensor`. Its second dimension
              (``shape_list(output_layer)[1]``) becomes the vocab size;
            - `None`, in which case a `tf.layers.Dense` layer with
              :attr:`vocab_size` units is created.
        vocab_size (int, optional): Vocabulary size. Required when
            :attr:`output_layer` is `None`; overridden when
            :attr:`output_layer` is a tensor.
        output_layer_bias (bool): Passed as `use_bias` to the created `Dense`
            layer, or forwarded to `_make_output_layer_from_tensor`.
        variable_scope: Variable scope in which the `Dense` layer is created,
            or forwarded to `_make_output_layer_from_tensor`.

    Returns:
        A tuple ``(output_layer, vocab_size)``. The vocab size is the given
        :attr:`vocab_size`, except in the tensor case, where it is read from
        the tensor's shape.

    Raises:
        ValueError: If both :attr:`output_layer` and :attr:`vocab_size` are
            `None`, or if :attr:`output_layer` is not a callable, a tensor,
            or `None`.
    """
    _vocab_size = vocab_size
    if is_callable(output_layer):
        _output_layer = output_layer
    elif tf.contrib.framework.is_tensor(output_layer):
        _vocab_size = shape_list(output_layer)[1]
        _output_layer = _make_output_layer_from_tensor(output_layer,
                                                       _vocab_size,
                                                       output_layer_bias,
                                                       variable_scope)
    elif output_layer is None:
        if _vocab_size is None:
            raise ValueError(
                "Either `output_layer` or `vocab_size` must be provided. "
                "Set `output_layer=tf.identity` if no output layer is "
                "wanted.")
        with tf.variable_scope(variable_scope):
            # pylint: disable=redefined-variable-type
            _output_layer = tf.layers.Dense(units=_vocab_size,
                                            use_bias=output_layer_bias)
    else:
        # Format the offending type into the message itself; passing it as a
        # second positional argument made the error render as a tuple.
        raise ValueError(
            "output_layer should be a callable layer, a tensor, or None. "
            "Unsupported type: {}".format(type(output_layer)))

    return _output_layer, _vocab_size
Example 3
0
 def _make_bucket_length_fn(self):
     length_fn = self._hparams.bucket_length_fn
     if not length_fn:
         length_fn = lambda x: x[self.length_name]
     elif not is_callable(length_fn):
         # pylint: disable=redefined-variable-type
         length_fn = utils.get_function(length_fn, ["texar.custom"])
     return length_fn
Example 4
0
 def _make_bucket_length_fn(self):
     length_fn = self._hparams.bucket_length_fn
     if not length_fn:
         # Uses the length of the first text data
         i = -1
         for i, hparams_i in enumerate(self._hparams.datasets):
             if _is_text_data(hparams_i["data_type"]):
                 break
         if i < 0:
             raise ValueError("Undefined `length_fn`.")
         length_fn = lambda x: x[self.length_name(i)]
     elif not is_callable(length_fn):
         # pylint: disable=redefined-variable-type
         length_fn = utils.get_function(length_fn, ["texar.custom"])
     return length_fn
Example 5
0
    def _make_other_transformations(other_trans_hparams, data_spec):
        """Creates a list of tranformation functions based on the
        hyperparameters.

        Args:
            other_trans_hparams (list): A list of transformation functions,
                names, or full paths.
            data_spec: An instance of :class:`texar.data._DataSpec` to
                be passed to transformation functions.

        Returns:
            A list of transformation functions.
        """
        other_trans = []
        for tran in other_trans_hparams:
            if not is_callable(tran):
                tran = utils.get_function(tran, ["texar.custom"])
            other_trans.append(dsutils.make_partial(tran, data_spec))
        return other_trans
Example 6
0
    def _parse(
            hparams,  # pylint: disable=too-many-branches, too-many-statements
            default_hparams,
            allow_new_hparam=False):
        """Parses hyperparameters.

        Merges :attr:`hparams` into a deep copy of :attr:`default_hparams`.
        Each provided value is type-checked against its default and, where
        possible, coerced to the default's type. Dictionary-valued entries
        are parsed recursively into :class:`HParams`.

        Args:
            hparams (dict): Hyperparameters. If `None`, all hyperparameters are
                set to default values.
            default_hparams (dict): Hyperparameters with default values.
                May be `None` only if :attr:`hparams` is also `None`.
            allow_new_hparam (bool): If `False` (default), :attr:`hparams`
                cannot contain hyperparameters that are not included in
                :attr:`default_hparams`, except the case of :attr:`"kwargs"`.

        Returns:
            A dictionary of parsed hyperparameters. Returns `None` if both
            :attr:`hparams` and :attr:`default_hparams` are `None`.

        Raises:
            ValueError: If :attr:`hparams` is not `None` and
                :attr:`default_hparams` is `None`.
            ValueError: If :attr:`default_hparams` contains "kwargs" but does
                not contain "type".
            ValueError: If :attr:`hparams` contains a name that is not in
                :attr:`default_hparams` and :attr:`allow_new_hparam` is
                `False`.
            ValueError: If a value's type does not match its default's type
                and cannot be converted to it.
        """
        if hparams is None and default_hparams is None:
            return None

        # No user overrides: parse the defaults against themselves.
        if hparams is None:
            return HParams._parse(default_hparams, default_hparams)

        if default_hparams is None:
            raise ValueError("`default_hparams` cannot be `None` if `hparams` "
                             "is not `None`.")
        # Names listed under the special "@no_typecheck" key of the defaults
        # bypass the type checks performed below.
        no_typecheck_names = default_hparams.get("@no_typecheck", [])

        if "kwargs" in default_hparams and "type" not in default_hparams:
            raise ValueError("Ill-defined hyperparameter structure: 'kwargs' "
                             "must accompany with 'type'.")

        # Work on a deep copy so the assignments below never mutate the
        # caller's `default_hparams`.
        parsed_hparams = copy.deepcopy(default_hparams)

        # Parse recursively for params of type dictionary that are missing
        # in `hparams`.
        for name, value in default_hparams.items():
            if name not in hparams and isinstance(value, dict):
                if name == "kwargs" and "type" in hparams and \
                        hparams["type"] != default_hparams["type"]:
                    # Set params named "kwargs" to empty dictionary if "type"
                    # takes value other than default.
                    parsed_hparams[name] = HParams({}, {})
                else:
                    parsed_hparams[name] = HParams(value, value)

        # Function-local import; presumably to avoid a circular import with
        # texar.utils -- TODO confirm.
        from texar.utils.dtypes import is_callable

        # Parse hparams
        for name, value in hparams.items():
            if name not in default_hparams:
                if allow_new_hparam:
                    parsed_hparams[name] = HParams._parse_value(value, name)
                    continue
                else:
                    raise ValueError(
                        "Unknown hyperparameter: %s. Only hyperparameters "
                        "named 'kwargs' hyperparameters can contain new "
                        "entries undefined in default hyperparameters." % name)

            if value is None:
                # NOTE(review): there is no `continue` here, so execution
                # falls through and this assignment can be overwritten below.
                # - A `None` default, an "@no_typecheck" name, or a "type"
                #   accompanying "kwargs" stores the given `value` instead.
                # - Otherwise the coercion at the bottom runs on `None`.
                #   For example, `str(None)` gives "None", while `int(None)`
                #   raises TypeError, which is reported as a ValueError.
                # A `continue` may be missing -- confirm intended semantics.
                parsed_hparams[name] = \
                    HParams._parse_value(parsed_hparams[name])

            default_value = default_hparams[name]
            # A `None` default carries no type information, so the value is
            # accepted without type checking.
            if default_value is None:
                parsed_hparams[name] = HParams._parse_value(value)
                continue

            # Parse recursively for params of type dictionary.
            if isinstance(value, dict):
                if name not in no_typecheck_names \
                        and not isinstance(default_value, dict):
                    raise ValueError(
                        "Hyperparameter '%s' must have type %s, got %s" %
                        (name, _type_name(default_value), _type_name(value)))
                if name == "kwargs":
                    if "type" in hparams and \
                            hparams["type"] != default_hparams["type"]:
                        # Leave "kwargs" as-is if "type" takes value
                        # other than default.
                        parsed_hparams[name] = HParams(value, value)
                    else:
                        # Allow new hyperparameters if "type" takes default
                        # value
                        parsed_hparams[name] = HParams(value,
                                                       default_value,
                                                       allow_new_hparam=True)
                elif name in no_typecheck_names:
                    parsed_hparams[name] = HParams(value, value)
                else:
                    parsed_hparams[name] = HParams(value, default_value,
                                                   allow_new_hparam)
                continue

            # Do not type-check hyperparameter named "type" and accompanied
            # with "kwargs"
            if name == "type" and "kwargs" in default_hparams:
                parsed_hparams[name] = value
                continue

            if name in no_typecheck_names:
                parsed_hparams[name] = value
            elif isinstance(value, type(default_value)):
                parsed_hparams[name] = value
            elif is_callable(value) and is_callable(default_value):
                parsed_hparams[name] = value
            else:
                # Coerce to the default's type. For example, an int given for
                # a str default becomes its string form. Only TypeError is
                # translated; a ValueError raised by the constructor itself
                # (e.g. `int("abc")`) propagates unchanged.
                try:
                    parsed_hparams[name] = type(default_value)(value)
                except TypeError:
                    raise ValueError(
                        "Hyperparameter '%s' must have type %s, got %s" %
                        (name, _type_name(default_value), _type_name(value)))

        return parsed_hparams