def check_or_get_class(class_or_name, module_path=None, superclass=None):
    """Returns the class and checks if the class inherits :attr:`superclass`.

    Args:
        class_or_name: Name or full path to the class, or the class itself.
        module_path (list, optional): Paths to candidate modules to search
            for the class. This is used if :attr:`class_or_name` is a string
            and the class cannot be located solely based on
            :attr:`class_or_name`. The first module in the list that contains
            the class is used.
        superclass (optional): A class, or a list/tuple of classes, that the
            target class must inherit. When several classes are given,
            inheriting any one of them is sufficient.

    Returns:
        The target class.

    Raises:
        ValueError: If class is not found based on :attr:`class_or_name` and
            :attr:`module_path`.
        TypeError: If class does not inherit :attr:`superclass`.
    """
    class_ = class_or_name
    if is_str(class_):
        class_ = get_class(class_, module_path)
    if superclass is not None:
        # `issubclass` accepts a class or a *tuple* of classes, but raises on
        # a list; normalize the documented "list of classes" form.
        expected = tuple(superclass) if isinstance(superclass, list) \
            else superclass
        if not issubclass(class_, expected):
            raise TypeError("A subclass of {} is expected. Got: {}".format(
                superclass, class_))
    return class_
def _recur_join(s):
    """Joins a string list, or a nested list of string lists, with ``sep``.

    ``sep`` and ``_maybe_list_to_array`` come from the enclosing scope (not
    visible here). An empty input yields ``''``. Whether ``s`` is a leaf is
    decided by looking only at its first element.
    """
    if len(s) == 0:
        return ''
    if not is_str(s[0]):
        # Nested level: join each child, then hand the results to
        # `_maybe_list_to_array` together with the original `s`
        # (presumably to mirror its container type — helper not visible).
        joined = list(map(_recur_join, s))
        return _maybe_list_to_array(joined, s)
    return sep.join(s)
def _recur_split(s, dtype_as):
    """Splits a string, or a (possibly nested) list of strings, on whitespace.

    A leaf string is split with ``str.split()`` and passed to
    ``_maybe_list_to_array`` along with ``dtype_as``. Non-string inputs are
    recursed into element-wise, and the results are passed to
    ``_maybe_list_to_array`` along with the original container.
    """
    if not is_str(s):
        pieces = [_recur_split(item, dtype_as) for item in s]
        return _maybe_list_to_array(pieces, s)
    return _maybe_list_to_array(s.split(), dtype_as)
def _recur_strip(s):
    """Normalizes whitespace and strips ``bos_token`` from string(s).

    ``bos_token`` and ``_maybe_list_to_array`` come from the enclosing scope.
    A string is collapsed to single-space-separated tokens. If ``bos_token``
    is non-empty, every occurrence of ``bos_token + ' '`` is then removed.
    Lists are processed element-wise and recursively.
    """
    if not is_str(s):
        return _maybe_list_to_array([_recur_strip(item) for item in s], s)
    # `str.split()` with no argument already drops leading/trailing spaces.
    normalized = ' '.join(s.split())
    if bos_token == '':
        return normalized
    # NOTE(review): this removes *every* "<bos> " occurrence (not only
    # leading ones) and keeps a bos_token that has no trailing space.
    return normalized.replace(bos_token + ' ', '')
def _recur_strip(s):
    """Truncates string(s) at the first ``eos_token``.

    ``eos_token`` and ``_maybe_list_to_array`` come from the enclosing scope.
    If the token appears among the whitespace-split tokens of a string, the
    tokens before its first occurrence are re-joined with single spaces.
    Otherwise the string is returned unchanged, without whitespace
    normalization. Lists are processed element-wise and recursively.
    """
    if not is_str(s):
        return _maybe_list_to_array([_recur_strip(item) for item in s], s)
    tokens = s.split()
    try:
        cut = tokens.index(eos_token)
    except ValueError:
        return s
    return ' '.join(tokens[:cut])
def check_or_get_instance_with_redundant_kwargs(
        ins_or_class_or_name, kwargs, module_paths=None, classtype=None):
    """Returns a class instance and checks types.

    Only those keyword arguments in :attr:`kwargs` that are included in the
    class construction method are used.

    Args:
        ins_or_class_or_name: Can be of 3 types:

            - A class to instantiate.
            - A string of the name or module path to a class to instantiate.
            - The class instance to check types.

        kwargs (dict): Keyword arguments for the class constructor. Only
            arguments accepted by the constructor are used; the filtering is
            delegated to ``get_instance_with_redundant_kwargs``. Ignored when
            :attr:`ins_or_class_or_name` is already an instance.
        module_paths (list, optional): Paths to candidate modules to search
            for the class. This is used if the class cannot be located solely
            based on :attr:`ins_or_class_or_name`. The first module in the
            list that contains the class is used.
        classtype (optional): A class, or a list/tuple of classes, of which
            the instance must be an instantiation. When several classes are
            given, being an instance of any one of them is sufficient.

    Returns:
        A class instance.

    Raises:
        ValueError: If class is not found based on
            :attr:`ins_or_class_or_name` and :attr:`module_paths`.
        TypeError: If the instance is not an instantiation of
            :attr:`classtype`.
    """
    ret = ins_or_class_or_name
    if is_str(ret) or isinstance(ret, type):
        ret = get_instance_with_redundant_kwargs(ret, kwargs, module_paths)
    if classtype is not None:
        # `isinstance` accepts a class or a *tuple* of classes, but raises on
        # a list; normalize the documented "list of classes" form.
        expected = tuple(classtype) if isinstance(classtype, list) \
            else classtype
        if not isinstance(ret, expected):
            raise TypeError("An instance of {} is expected. Got: {}".format(
                classtype, ret))
    return ret
def get_instance(class_or_name, kwargs, module_paths=None):
    """Creates a class instance.

    Args:
        class_or_name: A class, or its name or full path to a class to
            instantiate.
        kwargs (dict): Keyword arguments for the class constructor. `None`
            is treated as no arguments.
        module_paths (list, optional): Paths to candidate modules to search
            for the class. This is used if the class cannot be located solely
            based on :attr:`class_or_name`. The first module in the list that
            contains the class is used.

    Returns:
        A class instance.

    Raises:
        ValueError: If class is not found based on :attr:`class_or_name` and
            :attr:`module_paths`.
        ValueError: If :attr:`kwargs` contains an argument that is not among
            the names returned by ``get_args(class_.__init__)``.
    """
    # Locate the class
    class_ = class_or_name
    if is_str(class_):
        class_ = get_class(class_, module_paths)

    if kwargs is None:
        kwargs = {}

    # Check validity of arguments against the constructor's argument names.
    class_args = set(get_args(class_.__init__))
    for key in kwargs.keys():
        if key not in class_args:
            # Sorted so the error message is deterministic across runs.
            raise ValueError(
                "Invalid argument for class %s.%s: %s, valid args: %s" %
                (class_.__module__, class_.__name__, key, sorted(class_args)))

    return class_(**kwargs)
def get_layer(hparams):
    r"""Makes a layer instance.

    The layer must be an instance of :tf_main:`tf.layers.Layer <layers/Layer>`.

    Args:
        hparams (dict or HParams): Hyperparameters of the layer, with
            structure:

            .. code-block:: python

                {
                    "type": "LayerClass",
                    "kwargs": {
                        # Keyword arguments of the layer class
                        # ...
                    }
                }

            Here:

            `"type"`: str or layer class or layer instance
                The layer type. This can be

                - The string name or full module path of a layer class. If
                  the class name is provided, the class must be in module
                  :tf_main:`tf.layers <layers>`, :mod:`texar.tf.core`, or
                  :mod:`texar.tf.custom`.
                - A layer class.
                - An instance of a layer class.

                For example

                .. code-block:: python

                    "type": "Conv1D" # class name
                    "type": "texar.tf.core.MaxReducePooling1D" # module path
                    "type": "my_module.MyLayer" # module path
                    "type": tf.layers.Conv2D # class
                    "type": Conv1D(filters=10, kernel_size=2) # layer instance
                    "type": MyLayer(...) # layer instance

            `"kwargs"`: dict
                A dictionary of keyword arguments for constructor of the
                layer class. Ignored if :attr:`"type"` is a layer instance.
                When :attr:`hparams` is a plain `dict`, missing kwargs are
                filled from per-class defaults.

                - Arguments whose name ends with "activation" can be a
                  callable, or a `str` of the name or module path to the
                  activation function. They are converted with
                  :func:`get_activation_fn`.
                - Arguments named "\*_regularizer" and "\*_initializer" can
                  be a class instance, or a `dict` of hyperparameters of
                  respective regularizers and initializers. They are
                  converted with :func:`get_regularizer` and
                  :func:`get_initializer`, respectively.
                - Arguments named "\*_constraint" can be a callable, or a
                  `str` of the name or full path to the constraint function.
                  They are converted with :func:`get_constraint_fn`.

    Returns:
        A layer instance. If ``hparams["type"]`` is a layer instance, returns
        it directly.

    Raises:
        ValueError: If :attr:`hparams` is `None`.
        ValueError: If the resulting layer is not an instance of
            :tf_main:`tf.layers.Layer <layers/Layer>`.
    """
    if hparams is None:
        raise ValueError("`hparams` must not be `None`.")

    layer_type = hparams["type"]
    if not is_str(layer_type) and not isinstance(layer_type, type):
        # Already a layer instance: used as-is, "kwargs" is ignored.
        layer = layer_type
    else:
        layer_modules = [
            "tensorflow.layers", "texar.tf.core", "texar.tf.custom"
        ]
        layer_class = utils.check_or_get_class(layer_type, layer_modules)
        if isinstance(hparams, dict):
            # Plain dicts get per-class default kwargs merged in.
            default_kwargs = _layer_class_to_default_kwargs_map.get(
                layer_class, {})
            default_hparams = {"type": layer_type, "kwargs": default_kwargs}
            hparams = HParams(hparams, default_hparams)

        # Convert specially-named kwargs from their hparams form into the
        # objects the layer constructor expects.
        kwargs = {}
        for k, v in hparams.kwargs.items():
            if k.endswith('_regularizer'):
                kwargs[k] = get_regularizer(v)
            elif k.endswith('_initializer'):
                kwargs[k] = get_initializer(v)
            elif k.endswith('activation'):
                kwargs[k] = get_activation_fn(v)
            elif k.endswith('_constraint'):
                kwargs[k] = get_constraint_fn(v)
            else:
                kwargs[k] = v

        # The class was already resolved above; pass it directly instead of
        # resolving `layer_type` against the module list a second time.
        layer = utils.get_instance(layer_class, kwargs, layer_modules)

    if not isinstance(layer, tf.layers.Layer):
        raise ValueError("layer must be an instance of `tf.layers.Layer`.")

    return layer
def get_rnn_cell(hparams=None, mode=None):
    """Creates an RNN cell.

    See :func:`~texar.tf.core.default_rnn_cell_hparams` for all
    hyperparameters and default values.

    Args:
        hparams (dict or HParams, optional): Cell hyperparameters. Missing
            hyperparameters are set to default values.
        mode (optional): A Tensor taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
            `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout will be
            controlled by :func:`texar.tf.global_mode`.

    Returns:
        A cell instance. When ``num_layers > 1`` the per-layer cells are
        stacked in an ``rnn.MultiRNNCell``.

    Raises:
        ValueError: If hparams["num_layers"] < 1.
        ValueError: If hparams["dropout"]["variational_recurrent"] is True
            and hparams["dropout"]["input_size"] does not have exactly
            ``num_layers`` entries.
        ValueError: If hparams["num_layers"]>1 and hparams["type"] is a class
            instance.
        Error raised by ``utils.check_or_get_instance`` if the cell is not an
            :tf_main:`RNNCell <contrib/rnn/RNNCell>` instance.
            NOTE(review): likely a TypeError rather than ValueError; confirm.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_rnn_cell_hparams())

    num_layers = hparams["num_layers"]
    if num_layers < 1:
        # Without this check, `cells[0]` below fails with an opaque IndexError.
        raise ValueError(
            "`num_layers` must be a positive integer. Got: {}".format(
                num_layers))

    d_hp = hparams["dropout"]
    if d_hp["variational_recurrent"] and \
            len(d_hp["input_size"]) != num_layers:
        raise ValueError(
            "If variational_recurrent=True, input_size must be a list of "
            "num_layers(%d) integers. Got len(input_size)=%d."
            % (num_layers, len(d_hp["input_size"])))

    # Loop invariants: the cell spec, where to look it up, and whether any
    # dropout wrapper is needed at all.
    cell_type = hparams["type"]
    if num_layers > 1 and not is_str(cell_type) and \
            not isinstance(cell_type, type):
        # A single pre-built instance cannot be reused for several layers.
        raise ValueError(
            "If 'num_layers'>1, then 'type' must be a cell class or "
            "its name/module path, rather than a cell instance.")
    cell_modules = [
        'tensorflow.nn.rnn_cell', 'tensorflow.contrib.rnn', 'texar.tf.custom'
    ]
    cell_kwargs = hparams["kwargs"].todict()
    use_dropout = (d_hp["input_keep_prob"] < 1.0 or
                   d_hp["output_keep_prob"] < 1.0 or
                   d_hp["state_keep_prob"] < 1.0)

    cells = []
    for layer_i in range(num_layers):
        # Create the basic cell
        cell = utils.check_or_get_instance(
            cell_type, cell_kwargs, cell_modules, rnn.RNNCell)

        # Optionally add dropout
        if use_dropout:
            vr_kwargs = {}
            if d_hp["variational_recurrent"]:
                vr_kwargs = {
                    "variational_recurrent": True,
                    "input_size": d_hp["input_size"][layer_i],
                    "dtype": tf.float32
                }
            # `switch_dropout` is called per layer, as in the original, so
            # each wrapper gets its own keep-prob values.
            input_keep_prob = switch_dropout(d_hp["input_keep_prob"], mode)
            output_keep_prob = switch_dropout(d_hp["output_keep_prob"], mode)
            state_keep_prob = switch_dropout(d_hp["state_keep_prob"], mode)
            cell = rnn.DropoutWrapper(
                cell=cell,
                input_keep_prob=input_keep_prob,
                output_keep_prob=output_keep_prob,
                state_keep_prob=state_keep_prob,
                **vr_kwargs)

        # Optionally add residual and highway connections (never on the
        # first layer).
        if layer_i > 0:
            if hparams["residual"]:
                cell = rnn.ResidualWrapper(cell)
            if hparams["highway"]:
                cell = rnn.HighwayWrapper(cell)

        cells.append(cell)

    if num_layers > 1:
        return rnn.MultiRNNCell(cells)
    return cells[0]
def _maybe_name_to_id(self, name_or_id):
    """Maps a data name to its id; non-string inputs are returned unchanged.

    Raises:
        ValueError: If ``name_or_id`` is a string that is not a key of
            ``self._name_to_id``.
    """
    if not is_str(name_or_id):
        return name_or_id
    lookup = self._name_to_id
    if name_or_id in lookup:
        return lookup[name_or_id]
    raise ValueError("Unknown data name: {}".format(name_or_id))
def _maybe_str_to_list(list_or_str):
    """Splits a string on whitespace into a list; other inputs pass through."""
    return list_or_str.split() if is_str(list_or_str) else list_or_str