def get_function(fn_or_name, module_paths=None):
    """Returns the function of specified name and module.

    Args:
        fn_or_name (str or callable): Name or full path to a function, or the
            function itself.
        module_paths (list, optional): A list of paths to candidate modules to
            search for the function. This is used only when the function
            cannot be located solely based on :attr:`fn_or_name`. The first
            module in the list that contains the function is used.

    Returns:
        A function. If :attr:`fn_or_name` is already callable, it is returned
        unchanged.

    Raises:
        ValueError: If the function cannot be located either directly or in
            any of :attr:`module_paths`.
    """
    if is_callable(fn_or_name):
        return fn_or_name

    fn = locate(fn_or_name)
    if fn is None and module_paths is not None:
        # Fall back to "<module_path>.<fn_or_name>" for each candidate module,
        # in order; the first one that resolves wins.
        for module_path in module_paths:
            fn = locate('.'.join([module_path, fn_or_name]))
            if fn is not None:
                break

    if fn is None:
        raise ValueError("Method not found in {}: {}".format(
            module_paths, fn_or_name))

    return fn
def _make_output_layer(output_layer, vocab_size, output_layer_bias,
                       variable_scope):
    """Makes a decoder output layer.

    Args:
        output_layer: One of:

            - a callable, which is used as the output layer as-is;
            - a tensor, which is turned into an output layer by
              :func:`_make_output_layer_from_tensor`. The vocab size is then
              taken from index 1 of the tensor's shape list;
            - `None`, in which case a new `tf.layers.Dense` layer with
              :attr:`vocab_size` units is created under
              :attr:`variable_scope`.
        vocab_size: Vocabulary size. It is required when
            :attr:`output_layer` is `None` and is overridden when
            :attr:`output_layer` is a tensor.
        output_layer_bias (bool): Passed as `use_bias` to the dense layer,
            and to :func:`_make_output_layer_from_tensor` for the tensor case.
        variable_scope: Variable scope used when building the output layer.

    Returns:
        A tuple `(output_layer, vocab_size)`.

    Raises:
        ValueError: If both :attr:`output_layer` and :attr:`vocab_size` are
            `None`, or if :attr:`output_layer` has an unsupported type.
    """
    _vocab_size = vocab_size
    if is_callable(output_layer):
        _output_layer = output_layer
    elif tf.contrib.framework.is_tensor(output_layer):
        _vocab_size = shape_list(output_layer)[1]
        _output_layer = _make_output_layer_from_tensor(
            output_layer, _vocab_size, output_layer_bias, variable_scope)
    elif output_layer is None:
        if _vocab_size is None:
            raise ValueError(
                "Either `output_layer` or `vocab_size` must be provided. "
                "Set `output_layer=tf.identity` if no output layer is "
                "wanted.")
        with tf.variable_scope(variable_scope):
            # pylint: disable=redefined-variable-type
            _output_layer = tf.layers.Dense(
                units=_vocab_size, use_bias=output_layer_bias)
    else:
        # Format the type into the message. Passing it as a second positional
        # argument made the exception message render as a tuple.
        raise ValueError(
            "output_layer should be a callable layer, a tensor, or None. "
            "Unsupported type: {}".format(type(output_layer)))

    return _output_layer, _vocab_size
def _make_bucket_length_fn(self):
    """Returns the function that computes an example's bucketing length.

    The ``bucket_length_fn`` hyperparameter is interpreted as follows:

    - empty / falsy: the returned function reads ``self.length_name`` from
      the example. The attribute is looked up each time the function is
      called.
    - a callable: it is returned unchanged.
    - anything else, e.g. a function name or full path: it is resolved with
      :func:`utils.get_function`, which also searches ``texar.custom``.
    """
    configured_fn = self._hparams.bucket_length_fn

    if not configured_fn:
        def _length_of(example):
            return example[self.length_name]
        return _length_of

    if is_callable(configured_fn):
        return configured_fn

    return utils.get_function(configured_fn, ["texar.custom"])
def _make_bucket_length_fn(self):
    """Returns the function that computes an example's bucketing length.

    If the ``bucket_length_fn`` hyperparameter is empty, the length of the
    first dataset in ``hparams.datasets`` whose ``data_type`` is text is used.
    A callable is returned unchanged. Any other value, e.g. a function name or
    full path, is resolved with :func:`utils.get_function`, which also
    searches ``texar.custom``.

    Raises:
        ValueError: If no length function is given and no dataset is text
            data (including when there are no datasets at all).
    """
    length_fn = self._hparams.bucket_length_fn
    if not length_fn:
        # Uses the length of the first text data.
        text_index = None
        for i, hparams_i in enumerate(self._hparams.datasets):
            if _is_text_data(hparams_i["data_type"]):
                text_index = i
                break
        # Track the hit explicitly. Otherwise a loop that finds no text data
        # would leave the loop variable at the last (non-text) dataset and
        # silently produce a wrong length function.
        if text_index is None:
            raise ValueError("Undefined `length_fn`.")
        length_fn = lambda x: x[self.length_name(text_index)]
    elif not is_callable(length_fn):
        # pylint: disable=redefined-variable-type
        length_fn = utils.get_function(length_fn, ["texar.custom"])
    return length_fn
def _make_other_transformations(other_trans_hparams, data_spec):
    """Creates a list of transformation functions from the hyperparameters.

    Args:
        other_trans_hparams (list): Transformation functions, or their names
            or full paths. Non-callable entries are resolved with
            :func:`utils.get_function`, which also searches ``texar.custom``.
        data_spec: An instance of :class:`texar.data._DataSpec`. It is
            passed, together with each transformation, to
            :func:`dsutils.make_partial`.

    Returns:
        A list of transformation functions, in the same order as
        :attr:`other_trans_hparams`.
    """
    def _resolve(tran):
        # Callables are used directly; anything else is looked up by name.
        if is_callable(tran):
            return tran
        return utils.get_function(tran, ["texar.custom"])

    return [dsutils.make_partial(_resolve(tran), data_spec)
            for tran in other_trans_hparams]
def _parse(hparams,  # pylint: disable=too-many-branches, too-many-statements
           default_hparams,
           allow_new_hparam=False):
    """Parses hyperparameters.

    Args:
        hparams (dict): Hyperparameters. If `None`, all hyperparameters are
            set to default values.
        default_hparams (dict): Hyperparameters with default values. May be
            `None` only if :attr:`hparams` is also `None`.
        allow_new_hparam (bool): If `False` (default), :attr:`hparams`
            cannot contain hyperparameters that are not included in
            :attr:`default_hparams`, except the case of :attr:`"kwargs"`.

    Return:
        A dictionary of parsed hyperparameters. Returns `None` if both
        :attr:`hparams` and :attr:`default_hparams` are `None`. Nested dict
        values are wrapped as :class:`HParams`.

    Raises:
        ValueError: If :attr:`hparams` is not `None` and
            :attr:`default_hparams` is `None`.
        ValueError: If :attr:`default_hparams` contains "kwargs" but does
            not contain "type".
        ValueError: If :attr:`hparams` contains a name missing from
            :attr:`default_hparams` and :attr:`allow_new_hparam` is `False`.
        ValueError: If a value's type does not match its default and cannot
            be converted to it, i.e. the conversion raises `TypeError`.
    """
    if hparams is None and default_hparams is None:
        return None

    if hparams is None:
        # No overrides: parse the defaults against themselves.
        return HParams._parse(default_hparams, default_hparams)

    if default_hparams is None:
        raise ValueError("`default_hparams` cannot be `None` if `hparams` "
                         "is not `None`.")
    # Names listed under "@no_typecheck" bypass the type checks below.
    no_typecheck_names = default_hparams.get("@no_typecheck", [])

    if "kwargs" in default_hparams and "type" not in default_hparams:
        raise ValueError("Ill-defined hyperparameter structure: 'kwargs' "
                         "must accompany with 'type'.")

    # Start from a deep copy so the caller's defaults are never mutated.
    parsed_hparams = copy.deepcopy(default_hparams)

    # Parse recursively for params of type dictionary that are missing
    # in `hparams`.
    for name, value in default_hparams.items():
        if name not in hparams and isinstance(value, dict):
            if name == "kwargs" and "type" in hparams and \
                    hparams["type"] != default_hparams["type"]:
                # Set params named "kwargs" to empty dictionary if "type"
                # takes value other than default.
                parsed_hparams[name] = HParams({}, {})
            else:
                parsed_hparams[name] = HParams(value, value)

    # Function-level import; presumably avoids a circular import -- verify.
    from texar.utils.dtypes import is_callable

    # Parse hparams
    for name, value in hparams.items():
        if name not in default_hparams:
            if allow_new_hparam:
                parsed_hparams[name] = HParams._parse_value(value, name)
                continue
            else:
                raise ValueError(
                    "Unknown hyperparameter: %s. Only hyperparameters "
                    "named 'kwargs' hyperparameters can contain new "
                    "entries undefined in default hyperparameters." % name)

        # NOTE(review): there is no `continue` after this assignment, so
        # every branch below either overwrites parsed_hparams[name] or
        # raises. A `None` value therefore does not keep the default.
        # Assuming is_callable(None) is False, it falls through to
        # `type(default_value)(None)`: most default types raise ValueError
        # there, while e.g. a str default becomes the string 'None'.
        # Confirm whether a `continue` was intended.
        if value is None:
            parsed_hparams[name] = \
                HParams._parse_value(parsed_hparams[name])

        default_value = default_hparams[name]
        if default_value is None:
            # No default to type-check against: accept the value as given.
            parsed_hparams[name] = HParams._parse_value(value)
            continue

        # Parse recursively for params of type dictionary.
        if isinstance(value, dict):
            if name not in no_typecheck_names \
                    and not isinstance(default_value, dict):
                raise ValueError(
                    "Hyperparameter '%s' must have type %s, got %s" %
                    (name, _type_name(default_value), _type_name(value)))
            if name == "kwargs":
                if "type" in hparams and \
                        hparams["type"] != default_hparams["type"]:
                    # Leave "kwargs" as-is if "type" takes value
                    # other than default.
                    parsed_hparams[name] = HParams(value, value)
                else:
                    # Allow new hyperparameters if "type" takes default
                    # value
                    parsed_hparams[name] = HParams(value, default_value,
                                                   allow_new_hparam=True)
            elif name in no_typecheck_names:
                parsed_hparams[name] = HParams(value, value)
            else:
                parsed_hparams[name] = HParams(value, default_value,
                                               allow_new_hparam)
            continue

        # Do not type-check hyperparameter named "type" and accompanied
        # with "kwargs"
        if name == "type" and "kwargs" in default_hparams:
            parsed_hparams[name] = value
            continue

        if name in no_typecheck_names:
            parsed_hparams[name] = value
        elif isinstance(value, type(default_value)):
            parsed_hparams[name] = value
        elif is_callable(value) and is_callable(default_value):
            parsed_hparams[name] = value
        else:
            # Try converting to the default's type. Only TypeError is
            # translated here; other exceptions from the conversion (e.g. a
            # ValueError from int("abc")) propagate unchanged.
            try:
                parsed_hparams[name] = type(default_value)(value)
            except TypeError:
                raise ValueError(
                    "Hyperparameter '%s' must have type %s, got %s" %
                    (name, _type_name(default_value), _type_name(value)))

    return parsed_hparams