Code Example #1
    def test_type_kwargs(self):
        r"""The the special cases involving "type" and "kwargs"
        hyperparameters.
        """
        default_hparams = {"type": "type_name", "kwargs": {"arg1": "argv1"}}

        hparams = {"type": "type_name"}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), default_hparams["kwargs"])

        hparams = {"type": "type_name", "kwargs": {"arg2": "argv2"}}
        hparams_ = HParams(hparams, default_hparams)
        full_kwargs = {}
        full_kwargs.update(default_hparams["kwargs"])
        full_kwargs.update(hparams["kwargs"])
        self.assertEqual(hparams_.kwargs.todict(), full_kwargs)

        hparams = {"kwargs": {"arg2": "argv2"}}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), full_kwargs)

        hparams = {"type": "type_name2"}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), {})

        hparams = {"type": "type_name2", "kwargs": {"arg3": "argv3"}}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.kwargs.todict(), hparams["kwargs"])
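
A minimal usage sketch of the merging rules the test above exercises, assuming texar-pytorch is installed and HParams is importable from texar.torch: keeping the default "type" merges the given "kwargs" into the default ones, while overriding "type" discards the default "kwargs".

from texar.torch import HParams

default_hparams = {"type": "type_name", "kwargs": {"arg1": "argv1"}}

# Same "type" as the default: "kwargs" are merged with the default kwargs.
same_type = HParams({"kwargs": {"arg2": "argv2"}}, default_hparams)
print(same_type.kwargs.todict())  # {'arg1': 'argv1', 'arg2': 'argv2'}

# Different "type": the default "kwargs" no longer apply.
new_type = HParams({"type": "type_name2"}, default_hparams)
print(new_type.kwargs.todict())   # {}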
Code Example #2
File: main_train.py Project: mgupta1410/forte-1
def main():
    config_data = yaml.safe_load(open("config_data.yml", "r"))
    config_model = yaml.safe_load(open("config_model.yml", "r"))
    config_preprocess = yaml.safe_load(open("config_preprocessor.yml", "r"))

    config = HParams({}, default_hparams=None)
    config.add_hparam('config_data', config_data)
    config.add_hparam('config_model', config_model)
    config.add_hparam('preprocessor', config_preprocess)

    reader = CoNLL03Reader()

    # Keep the vocabulary processor as a simple counter
    vocab_processor = CoNLL03VocabularyProcessor()

    ner_trainer = CoNLLNERTrainer()
    ner_predictor = CoNLLNERPredictor()
    ner_evaluator = CoNLLNEREvaluator()

    train_pipe = TrainPipeline(train_reader=reader,
                               trainer=ner_trainer,
                               dev_reader=reader,
                               configs=config,
                               preprocessors=[vocab_processor],
                               predictor=ner_predictor,
                               evaluator=ner_evaluator)
    train_pipe.run()
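
A short sketch of the HParams pattern used in main() above, assuming texar.torch is importable; the added key and value are placeholders, not part of the original YAML configs. An HParams built with default_hparams=None starts from the given dict only, and dicts added via add_hparam become nested HParams.

from texar.torch import HParams

config = HParams({}, default_hparams=None)
config.add_hparam("config_data", {"num_epochs": 10})  # placeholder value
print(config.config_data.todict())                    # {'num_epochs': 10}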
Code Example #3
    def setUp(self):
        self._vocab_size = 10
        self._max_time = 16
        self._batch_size = 8
        self._emb_dim = 20
        self._attention_dim = 256
        self._inputs = torch.randint(
            self._vocab_size, size=(self._batch_size, self._max_time))
        embedding = torch.rand(
            self._vocab_size, self._emb_dim, dtype=torch.float)
        self._embedder = WordEmbedder(init_value=embedding)
        self._encoder_output = torch.rand(
            self._batch_size, self._max_time, 64)

        self._test_hparams = {}  # (cell_type, is_multi) -> hparams
        for cell_type in ["RNNCell", "LSTMCell", "GRUCell"]:
            hparams = {
                "rnn_cell": {
                    'type': cell_type,
                    'kwargs': {
                        'num_units': 256,
                    },
                },
                "attention": {
                    "kwargs": {
                        "num_units": self._attention_dim
                    },
                }
            }
            self._test_hparams[(cell_type, False)] = HParams(
                hparams, AttentionRNNDecoder.default_hparams())

        hparams = {
            "rnn_cell": {
                'type': 'LSTMCell',
                'kwargs': {
                    'num_units': 256,
                },
                'num_layers': 3,
            },
            "attention": {
                "kwargs": {
                    "num_units": self._attention_dim
                },
            }
        }
        self._test_hparams[("LSTMCell", True)] = HParams(
            hparams, AttentionRNNDecoder.default_hparams())
Code Example #4
File: optimization.py Project: VegB/VLN-Transformer
def get_grad_clip_fn(hparams: Optional[Union[HParams,
                                             Dict[str, Any]]] = None) -> \
        Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]:
    r"""Create a gradient clipping function.

    Args:
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.torch.core.default_optimization_hparams` for
            all hyperparameters and default values.

    Returns:
        A gradient clipping function.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_optimization_hparams())

    hparams_grad_clip = hparams["gradient_clip"]

    grad_clip_type = hparams_grad_clip["type"]
    if grad_clip_type == "" or grad_clip_type is None:
        grad_clip_fn = None
    else:
        grad_clip_modules = ['torch.nn.utils', 'texar.torch.custom']
        grad_clip_fn = utils.get_function(grad_clip_type, grad_clip_modules)
        grad_clip_fn_kwargs = hparams_grad_clip["kwargs"].todict()
        grad_clip_fn = functools.partial(grad_clip_fn, **grad_clip_fn_kwargs)

    return grad_clip_fn
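
A hedged usage sketch of the function above, assuming it is importable from texar.torch.core as the docstring's module path suggests. With the library defaults the "gradient_clip" type is empty, so None is returned; naming a function from torch.nn.utils yields a partial that is later called with parameters=... (see get_train_op below).

from texar.torch.core import get_grad_clip_fn

no_clip = get_grad_clip_fn()  # None with the default (empty) "gradient_clip" type

clip_fn = get_grad_clip_fn({
    "gradient_clip": {"type": "clip_grad_norm_", "kwargs": {"max_norm": 5.0}}
})
# Inside a training step:
#     clip_fn(parameters=model.parameters())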
Code Example #5
    def __init__(self, vocab: Dict[str, int], hparams=None):
        self._hparams = HParams(hparams, self.default_hparams())

        # Initialize embeddings
        init_fn_kwargs = self._hparams.init_fn.kwargs.todict()
        if "shape" in init_fn_kwargs or "size" in init_fn_kwargs:
            raise ValueError("Argument 'shape' or 'size' must not be "
                             "specified. They are inferred automatically.")
        init_fn: Callable[..., np.ndarray]
        init_fn = utils.get_function(
            self._hparams.init_fn.type,
            ["numpy.random", "numpy", "texar.torch.custom"])

        try:
            self._word_vecs = init_fn(  # type: ignore
                size=[len(vocab), self._hparams.dim],
                **init_fn_kwargs)
        except TypeError:
            self._word_vecs = init_fn(  # type: ignore
                shape=[len(vocab), self._hparams.dim],
                **init_fn_kwargs)

        # Optionally read embeddings from file
        if self._hparams.file is not None and self._hparams.file != "":
            read_fn: Callable[[str, Dict[str, int], np.ndarray], np.ndarray]
            read_fn = utils.get_function(  # type: ignore
                self._hparams.read_fn, [
                    "texar.torch.data.embedding", "texar.torch.data",
                    "texar.torch.custom"
                ])

            self._word_vecs = read_fn(
                self._hparams.file,  # type: ignore
                vocab,
                self._word_vecs)
Code Example #6
 def __init__(self, hparams):
     self._hparams = HParams(hparams, self.default_hparams())
     self.img_root = self._hparams.img_root
     self.transforms = self.build_transform(self._hparams.transforms)
     self.text_root = self._hparams.text_root
     self.vocab = Vocab(self._hparams.vocab_path)
     self.pathologies = self._hparams.pathologies
Code Example #7
File: optimization.py Project: VegB/VLN-Transformer
def get_train_op(params: Optional[Iterable[Union[torch.Tensor,
                                                 Dict[str, Any]]]] = None,
                 optimizer: Optional[Optimizer] = None,
                 scheduler: Optional[_LRScheduler] = None,
                 hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
        Callable[[], None]:
    r"""Creates a training op.

    Args:
        params: an iterable of :class:`torch.Tensor` or
            :class:`dict`. Specifies what Tensors should be optimized.
        optimizer: A :torch_docs:`torch.optim.Optimizer
            <optim.html#torch.optim.Optimizer>` instance.
        scheduler: A :torch_docs:`torch.optim.lr_scheduler._LRScheduler
            <optim.html#how-to-adjust-learning-rate>` instance.
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.torch.core.default_optimization_hparams` for
            all hyperparameters and default values.

    Returns:
        The callable used for variable optimization.
    """
    hparams = HParams(hparams, default_optimization_hparams())

    if params is None and optimizer is None and scheduler is None:
        raise ValueError("'params', 'optimizer' and 'scheduler' must not be "
                         "None simultaneously.")

    if scheduler is None:
        if optimizer is None and params is not None:
            optimizer = get_optimizer(params, hparams)
        if optimizer is not None:
            scheduler = get_scheduler(optimizer, hparams)
    else:
        optimizer = scheduler.optimizer  # type: ignore

    grad_clip_fn = get_grad_clip_fn(hparams)

    # TODO: Support per-parameter options in the future.
    params_list: List[nn.Parameter] = []
    for param_group in optimizer.param_groups:  # type: ignore
        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            params_list.append(params)
        elif isinstance(params, list):
            params_list += params

    def _train_op():
        if grad_clip_fn is not None:
            grad_clip_fn(parameters=params_list)
        optimizer.step()
        # TODO: Ideally, scheduler should be used in the epoch level.
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    return _train_op
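
A usage sketch of get_train_op, assuming texar.torch.core exports it as the source file above suggests; the model, optimizer choice, and learning rate are illustrative. The returned callable clips gradients (if configured), steps the optimizer and any scheduler, and zeroes the gradients.

import torch
import torch.nn as nn
from texar.torch.core import get_train_op

model = nn.Linear(10, 2)
train_op = get_train_op(
    params=model.parameters(),
    hparams={"optimizer": {"type": "Adam", "kwargs": {"lr": 1e-3}}})

for _ in range(3):  # toy loop on random data
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    train_op()  # clip (if configured), optimizer.step(), scheduler.step() if any, zero_grad()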
Code Example #8
    def init_from_config(self, configs: Dict):
        """
        Initialize the pipeline with the configurations

        Args:
            configs: The configurations used to create the pipeline.

        Returns:

        """
        if "Reader" not in configs or configs["Reader"] is None:
            raise KeyError('No reader in the configuration')

        reader_config = configs["Reader"]

        reader, reader_hparams = create_class_with_kwargs(
            class_name=reader_config["type"],
            class_args=reader_config.get("kwargs", {}),
            h_params=reader_config.get("hparams", {}))

        self.set_reader(reader, reader_hparams)

        if "Processors" in configs and configs["Processors"] is not None:
            for processor_configs in configs["Processors"]:

                p_class = get_class(processor_configs["type"])
                if processor_configs.get("kwargs"):
                    processor_kwargs = processor_configs["kwargs"]
                else:
                    processor_kwargs = {}
                p = p_class(**processor_kwargs)

                hparams: Dict = {}

                if processor_configs.get("hparams"):
                    # Extract the hparams section and build hparams
                    processor_hparams = processor_configs["hparams"]

                    if processor_hparams.get("config_path"):
                        filebased_hparams = yaml.safe_load(
                            open(processor_hparams["config_path"]))
                    else:
                        filebased_hparams = {}
                    hparams.update(filebased_hparams)

                    if processor_hparams.get("overwrite_configs"):
                        overwrite_hparams = processor_hparams[
                            "overwrite_configs"]
                    else:
                        overwrite_hparams = {}
                    hparams.update(overwrite_hparams)
                default_processor_hparams = p_class.default_hparams()

                processor_hparams = HParams(hparams, default_processor_hparams)
                self.add_processor(p, processor_hparams)

            self.initialize()
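
A hedged sketch of the configuration shape init_from_config expects, written as the Python dict it receives after yaml.safe_load. The class paths and values are placeholders, not real Forte components; only the key names ("Reader", "Processors", "type", "kwargs", "hparams", "config_path", "overwrite_configs") come from the code above.

configs = {
    "Reader": {
        "type": "some.module.MyReader",         # hypothetical class path
        "kwargs": {},
        "hparams": {},
    },
    "Processors": [
        {
            "type": "some.module.MyProcessor",  # hypothetical class path
            "kwargs": {},
            "hparams": {
                "config_path": "processor_config.yml",   # optional, loaded via yaml
                "overwrite_configs": {"batch_size": 8},  # optional, applied last
            },
        },
    ],
}
# pipeline.init_from_config(configs)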
Code Example #9
    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):

        # SpanBERT checkpoints do not include a vocabulary file, so fall back
        # to the standard BERT vocabulary when a pre-trained SpanBERT is used.
        if pretrained_model_name is not None:
            if pretrained_model_name.startswith('spanbert'):
                pretrained_model_name = pretrained_model_name.lstrip('span')
        elif hparams is not None:
            hparams = HParams(hparams, None)
            if hparams.pretrained_model_name is not None and \
                    hparams.pretrained_model_name.startswith('spanbert'):
                pretrained_model_name = \
                    hparams.pretrained_model_name.lstrip('span')

        self.load_pretrained_config(pretrained_model_name, cache_dir, hparams)

        super().__init__(hparams=None)

        self.config = {
            'tokenize_chinese_chars': self.hparams['tokenize_chinese_chars'],
            'do_lower_case': self.hparams['do_lower_case'],
            'do_basic_tokenize': self.hparams['do_basic_tokenize'],
            'non_split_tokens': self.hparams['non_split_tokens'],
        }

        if self.pretrained_model_dir is not None:
            assert self.pretrained_model_name is not None
            vocab_file = os.path.join(self.pretrained_model_dir,
                                      self._VOCAB_FILE_MAP['vocab_file']
                                      [self.pretrained_model_name])

            if self._MAX_INPUT_SIZE.get(self.pretrained_model_name):
                self.max_len = self._MAX_INPUT_SIZE[self.pretrained_model_name]
        else:
            vocab_file = self.hparams['vocab_file']
            if self.hparams.get('max_len'):
                self.max_len = self.hparams['max_len']

        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path "
                             "'{}".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = dict((ids, tok) for tok, ids in self.vocab.items())

        self.do_basic_tokenize = self.hparams['do_basic_tokenize']
        if self.do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=self.hparams["do_lower_case"],
                never_split=self.hparams["non_split_tokens"],
                tokenize_chinese_chars=self.hparams["tokenize_chinese_chars"])
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
                                                      unk_token=self.unk_token)
Code Example #10
    def load_pretrained_config(self,
                               pretrained_model_name: Optional[str] = None,
                               cache_dir: Optional[str] = None,
                               hparams=None):
        r"""Load paths and configurations of the pre-trained model.

        Args:
            pretrained_model_name (optional): A str with the name
                of a pre-trained model to load. If `None`, will use the model
                name in :attr:`hparams`.
            cache_dir (optional): The path to a folder in which the
                pre-trained models will be cached. If `None` (default),
                a default directory will be used.
            hparams (dict or HParams, optional): Hyperparameters. Missing
                hyperparameter will be set to default values. See
                :meth:`default_hparams` for the hyperparameter structure
                and default values.
        """
        if not hasattr(self, "_hparams"):
            self._hparams = HParams(hparams, self.default_hparams())
        else:
            # Probably already parsed by subclasses. We rely on subclass
            # implementations to get this right.
            # As a sanity check, we require `hparams` to be `None` in this case.
            if hparams is not None:
                raise ValueError(
                    "`self._hparams` is already assigned, but `hparams` "
                    "argument is not None.")

        self.pretrained_model_dir = None
        self.pretrained_model_name = pretrained_model_name

        if self.pretrained_model_name is None:
            self.pretrained_model_name = self._hparams.pretrained_model_name
        if self.pretrained_model_name is not None:
            self.pretrained_model_dir = self.download_checkpoint(
                self.pretrained_model_name, cache_dir)
            pretrained_model_hparams = self._transform_config(
                self.pretrained_model_name, self.pretrained_model_dir)
            self._hparams = HParams(pretrained_model_hparams,
                                    self._hparams.todict())
Code Example #11
 def setUp(self):
     self._vocab_size = 4
     self._max_time = 8
     self._batch_size = 16
     self._emb_dim = 20
     self._inputs = torch.randint(
         self._vocab_size, size=(self._batch_size, self._max_time))
     embedding = torch.rand(
         self._vocab_size, self._emb_dim, dtype=torch.float)
     self._embedder = WordEmbedder(init_value=embedding)
     self._hparams = HParams(None, BasicRNNDecoder.default_hparams())
Code Example #12
    def __init__(self,
                 hparams,
                 device: Optional[torch.device] = None,
                 vocab: Optional[Vocab] = None,
                 embedding: Optional[Embedding] = None,
                 data_source: Optional[DataSource] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        if self._hparams.dataset.variable_utterance:
            raise NotImplementedError

        # Create vocabulary
        self._bos_token = self._hparams.dataset.bos_token
        self._eos_token = self._hparams.dataset.eos_token
        self._other_transforms = self._hparams.dataset.other_transformations
        bos = utils.default_str(self._bos_token, SpecialTokens.BOS)
        eos = utils.default_str(self._eos_token, SpecialTokens.EOS)
        if vocab is None:
            self._vocab = Vocab(self._hparams.dataset.vocab_file,
                                bos_token=bos,
                                eos_token=eos)
        else:
            self._vocab = vocab

        # Create embedding
        if embedding is not None:
            self._embedding = self.make_embedding(
                self._hparams.dataset.embedding_init,
                self._vocab.token_to_id_map_py)
        else:
            self._embedding = embedding

        self._delimiter = self._hparams.dataset.delimiter
        self._max_seq_length = self._hparams.dataset.max_seq_length
        self._length_filter_mode = _LengthFilterMode(
            self._hparams.dataset.length_filter_mode)
        self._pad_length = self._max_seq_length
        if self._pad_length is not None:
            self._pad_length += sum(
                int(x != '') for x in [self._bos_token, self._eos_token])

        if data_source is None:
            if (self._length_filter_mode is _LengthFilterMode.DISCARD
                    and self._max_seq_length is not None):
                data_source = TextLineDataSource(
                    self._hparams.dataset.files,
                    compression_type=self._hparams.dataset.compression_type,
                    delimiter=self._delimiter,
                    max_length=self._max_seq_length)
            else:
                data_source = TextLineDataSource(
                    self._hparams.dataset.files,
                    compression_type=self._hparams.dataset.compression_type)

        super().__init__(data_source, hparams, device=device)
Code Example #13
 def __init__(self,
              hparams,
              device: Optional[torch.device] = None,
              data_source: Optional[DataSource] = None):
     self._hparams = HParams(hparams, self.default_hparams())
     self._other_transforms = self._hparams.dataset.other_transformations
     self._data_type = get_numpy_dtype(self._hparams.dataset["data_type"])
     if data_source is None:
         data_source = TextLineDataSource(
             self._hparams.dataset.files,
             compression_type=self._hparams.dataset.compression_type)
     super().__init__(data_source, hparams, device=device)
Code Example #14
File: module_base.py Project: VegB/VLN-Transformer
 def __init__(self,
              hparams: Optional[Union[HParams, Dict[str, Any]]] = None):
     super().__init__()
     if not hasattr(self, '_hparams'):
         self._hparams = HParams(hparams, self.default_hparams())
     else:
         # Probably already parsed by subclasses. We rely on subclass
         # implementations to get this right.
         # As a sanity check, we require `hparams` to be `None` in this case.
         if hparams is not None:
             raise ValueError(
                 "`self._hparams` is already assigned, but `hparams` "
                 "argument is not None.")
Code Example #15
    def test_typecheck(self):
        r"""Tests type-check functionality.
        """
        def _foo():
            pass

        def _bar():
            pass

        default_hparams = {"fn": _foo, "fn_2": _foo}
        hparams = {"fn": _foo, "fn_2": _bar}
        hparams_ = HParams(hparams, default_hparams)
        self.assertEqual(hparams_.fn, default_hparams["fn"])
Code Example #16
    def init_from_config(self, configs: Dict):
        """
        Initialize the pipeline with the configurations

        Args:
            configs: The configurations used to create the pipeline.

        Returns:

        """
        # HParams cannot create HParams from the inner dict of list

        if "Processors" in configs and configs["Processors"] is not None:
            for processor_configs in configs["Processors"]:

                p_class = get_class(processor_configs["type"])
                if processor_configs.get("kwargs"):
                    processor_kwargs = processor_configs["kwargs"]
                else:
                    processor_kwargs = {}
                p = p_class(**processor_kwargs)

                hparams: Dict = {}

                if processor_configs.get("hparams"):
                    # Extract the hparams section and build hparams
                    processor_hparams = processor_configs["hparams"]

                    if processor_hparams.get("config_path"):
                        filebased_hparams = yaml.safe_load(
                            open(processor_hparams["config_path"]))
                    else:
                        filebased_hparams = {}
                    hparams.update(filebased_hparams)

                    if processor_hparams.get("overwrite_configs"):
                        overwrite_hparams = processor_hparams[
                            "overwrite_configs"]
                    else:
                        overwrite_hparams = {}
                    hparams.update(overwrite_hparams)
                default_processor_hparams = p_class.default_hparams()

                processor_hparams = HParams(hparams,
                                            default_processor_hparams)
                self.add_processor(p, processor_hparams)

            self.initialize()
Code Example #17
File: optimization.py Project: VegB/VLN-Transformer
def get_scheduler(optimizer: Optimizer,
                  hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
        Optional[_LRScheduler]:
    r"""Creates a scheduler instance.

    Args:
        optimizer: A :torch_docs:`torch.optim.Optimizer
            <optim.html#torch.optim.Optimizer>` instance.
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.torch.core.default_optimization_hparams` for
            all hyperparameters and default values.

    :return:
        A :torch_docs:`torch.optim.lr_scheduler._LRScheduler
        <optim.html#how-to-adjust-learning-rate>` instance.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_optimization_hparams())

    hparams_scheduler = hparams["learning_rate_decay"]

    scheduler_type = hparams_scheduler["type"]
    if scheduler_type == "" or scheduler_type is None:
        scheduler = None
    else:
        if isinstance(scheduler_type, _LRScheduler):
            scheduler_class = scheduler_type
        else:
            scheduler_modules = [
                'torch.optim.lr_scheduler', 'texar.torch.custom'
            ]
            try:
                scheduler_class = utils.check_or_get_class(  # type: ignore
                    scheduler_type, scheduler_modules, _LRScheduler)
            except TypeError:
                raise ValueError(
                    "Unrecognized lr_scheduler. Must be string name of the "
                    "lr_scheduler class, or the class which is a subclass of "
                    "torch.optim._LRScheduler.")

        scheduler_kwargs = hparams_scheduler["kwargs"].todict()
        scheduler_kwargs.update({"optimizer": optimizer})
        scheduler = scheduler_class(**scheduler_kwargs)  # type: ignore

    return scheduler
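
A usage sketch, assuming get_scheduler is importable from texar.torch.core; the optimizer and scheduler settings are illustrative. An empty "type" (the check above) yields None, while a class name from torch.optim.lr_scheduler is instantiated with the given kwargs plus the optimizer.

import torch
from texar.torch.core import get_scheduler

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(3))], lr=1e-3)

no_scheduler = get_scheduler(optimizer)  # None with the default empty "type"
scheduler = get_scheduler(optimizer, {
    "learning_rate_decay": {"type": "StepLR",
                            "kwargs": {"step_size": 10, "gamma": 0.5}}})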
Code Example #18
def get_embedding(num_embeds: Optional[int] = None,
                  init_value: Optional[torch.Tensor] = None,
                  hparams=None):
    r"""Creates embedding variable if not exists.

    Args:
        hparams (dict or HParams, optional): Embedding hyperparameters. Missing
            hyperparameters are set to default values. See
            :func:`~texar.torch.modules.default_embedding_hparams`
            for all hyperparameters and default values.

            If :attr:`init_value` is given, :attr:`hparams["initializer"]`,
            and :attr:`hparams["dim"]` are ignored.
        init_value (Tensor or numpy array, optional): Initial values of the
            embedding variable. If not given, embedding is initialized as
            specified in :attr:`hparams["initializer"]`.
        num_embeds (int, optional): The number of embedding items
            (e.g., vocabulary size). Required if :attr:`init_value` is
            not provided.

    Returns:
        A 2D :tensor:`Tensor` of the same shape with :attr:`init_value` or of
        the shape ``[num_embeds, hparams["dim"]]``.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_embedding_hparams())
    if init_value is None:
        initializer = layers.get_initializer(
            getattr(hparams, "initializer", None))
        # TODO Shibiao: add regularizer
        dim = hparams["dim"]
        if not isinstance(hparams["dim"], (list, tuple)):
            dim = [dim]
        embedding = torch.empty(size=[num_embeds] + dim)
        # initializer should be set by layers.get_initializer
        if initializer:
            initializer(embedding)
        else:
            torch.nn.init.xavier_uniform_(embedding)
    else:
        if torch.is_tensor(init_value):
            embedding = init_value  # Do not copy the tensor.
        else:
            embedding = torch.tensor(init_value, dtype=torch.float)

    return embedding
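
A short sketch of calling the function above (assumed to be in scope together with its texar dependencies): without init_value, an embedding of shape [num_embeds, hparams["dim"]] is created and initialized per hparams, falling back to Xavier uniform when no initializer is configured, as the code above shows.

import torch

# Build from hparams only; "dim" overrides the library default.
emb = get_embedding(num_embeds=100, hparams={"dim": 50})
print(emb.shape)  # torch.Size([100, 50])

# Or start from existing values; "initializer" and "dim" hparams are then ignored.
emb = get_embedding(init_value=torch.randn(100, 50))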
Code Example #19
File: optimization.py Project: VegB/VLN-Transformer
def get_optimizer(
        params: Iterable[Union[torch.Tensor, Dict[str, Any]]],
        hparams: Optional[Union[HParams, Dict[str, Any]]] = None) -> \
        Optimizer:
    r"""Creates a optimizer instance.

    Args:
        params: an iterable of :class:`torch.Tensor` or
            :class:`dict`. Specifies what Tensors should be optimized.
        hparams (dict or HParams, optional): hyperparameters. Missing
            hyperparameters are set to default values automatically. See
            :func:`~texar.torch.core.default_optimization_hparams` for
            all hyperparameters and default values.

    :return:
        The :torch_docs:`torch.optim.Optimizer
        <optim.html#torch.optim.Optimizer>` instance specified in
        :attr:`hparams`.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_optimization_hparams())

    hparams_opt = hparams["optimizer"]

    optimizer_type = hparams_opt["type"]
    if isinstance(optimizer_type, Optimizer):
        optimizer_class = optimizer_type
    else:
        optimizer_modules = ['torch.optim', 'texar.torch.custom']
        try:
            optimizer_class = utils.check_or_get_class(  # type: ignore
                optimizer_type, optimizer_modules, Optimizer)
        except TypeError:
            raise ValueError(
                "Unrecognized optimizer. Must be string name of the "
                "optimizer class, or the class which is a subclass of "
                "torch.optim.Optimizer, or an instance of the subclass of "
                "Optimizer.")

    optimizer_kwargs = hparams_opt["kwargs"].todict()
    optimizer_kwargs.update({"params": params})
    optimizer = optimizer_class(**optimizer_kwargs)  # type: ignore

    return optimizer
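
A usage sketch, assuming get_optimizer is importable from texar.torch.core. With hparams=None the library defaults apply (SGD in texar's defaults, stated here as an assumption), and an "optimizer" block selects a class by name from torch.optim or texar.torch.custom.

import torch
from texar.torch.core import get_optimizer

params = [torch.nn.Parameter(torch.zeros(2, 2))]

default_opt = get_optimizer(params)  # library defaults
adam_opt = get_optimizer(params, {"optimizer": {"type": "Adam",
                                                "kwargs": {"lr": 1e-3}}})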
Code Example #20
    def __init__(self,
                 hparams=None,
                 device: Optional[torch.device] = None,
                 data_source: Optional[DataSource] = None):
        self._hparams = HParams(hparams, self.default_hparams())

        feature_types = self._hparams.dataset.feature_original_types
        if feature_types is not None:
            warnings.warn(
                "'feature_original_types' of RecordData is deprecated. Please "
                "see default_hparams of RecordData for update instructions")
        if self._hparams.dataset.feature_types is not None:
            feature_types = self._hparams.dataset.feature_types
        elif feature_types is None:
            raise ValueError("'feature_types' must be specified")
        self._features = _convert_feature_hparams(feature_types)

        convert_types = self._hparams.dataset.feature_convert_types
        self._convert_types = {
            key: get_numpy_dtype(value)
            for key, value in convert_types.items()
        }
        for key, dtype in self._convert_types.items():
            self._features[key] = self._features[key]._replace(dtype=dtype)

        image_options = self._hparams.dataset.image_options
        if isinstance(image_options, HParams):
            image_options = [image_options]
        self._image_transforms: Dict[str, TransformFn] = {}
        for options in image_options:
            key = options.get('image_feature_name')
            if key is None or key not in self._features:
                continue
            self._image_transforms[key] = _create_image_transform(
                options.get('resize_height'), options.get('resize_width'),
                options.get('resize_method') or 'bilinear')

        self._other_transforms = self._hparams.dataset.other_transformations

        if data_source is None:
            data_source = PickleDataSource[Dict[str, Any]](
                self._hparams.dataset.files)

        super().__init__(data_source, hparams, device)
Code Example #21
File: layers.py Project: vllgle/texar-pytorch
def get_regularizer(hparams=None):
    r"""Returns a variable regularizer instance.

    See :func:`~texar.torch.core.default_regularizer_hparams` for all
    hyperparameters and default values.

    The "type" field can be a subclass
    of :class:`~texar.torch.core.regularizers.Regularizer`, its string name
    or module path, or a class instance.

    Args:
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters are set to default values.

    Returns:
        A :class:`~texar.torch.core.regularizers.Regularizer` instance.
        `None` if :attr:`hparams` is `None` or takes the default
        hyperparameter value.

    Raises:
        ValueError: The resulting regularizer is not an instance of
            :class:`~texar.torch.core.regularizers.Regularizer`.
    """

    if hparams is None:
        return None

    if isinstance(hparams, dict):
        hparams = HParams(hparams, default_regularizer_hparams())

    rgl = utils.check_or_get_instance(
        hparams.type, hparams.kwargs.todict(),
        ["texar.torch.core.regularizers", "texar.torch.custom"])

    if not isinstance(rgl, Regularizer):
        raise ValueError("The regularizer must be an instance of "
                         "texar.torch.core.regularizers.Regularizer.")

    if isinstance(rgl, L1L2) and rgl.l1 == 0. and rgl.l2 == 0.:
        return None

    return rgl
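
A usage sketch of get_regularizer (assumed in scope with its texar imports): "L1L2" is the built-in regularizer class referenced above, and an all-zero L1L2 is collapsed to None by the final check.

reg = get_regularizer({"type": "L1L2", "kwargs": {"l2": 1e-4}})            # L2 penalty
noop = get_regularizer({"type": "L1L2", "kwargs": {"l1": 0., "l2": 0.}})   # all-zero -> None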
Code Example #22
 def __init__(self, hparams, device: Optional[torch.device] = None,
              data_source: Optional[DataSource] = None):
     self._hparams = HParams(hparams, self.default_hparams())
     self._other_transforms = self._hparams.dataset.other_transformations
     data_type = self._hparams.dataset["data_type"]
     self._typecast_func: Union[Type[int], Type[float]]
     if data_type == "int":
         self._typecast_func = int
         self._to_data_type = np.int32
     elif data_type == "float":
         self._typecast_func = float
         self._to_data_type = np.float32
     else:
         raise ValueError("Incorrect 'data_type'. Currently 'int' and "
                          "'float' are supported. Received {}"
                          .format(data_type))
     if data_source is None:
         data_source = TextLineDataSource(
             self._hparams.dataset.files,
             compression_type=self._hparams.dataset.compression_type)
     super().__init__(data_source, hparams, device=device)
Code Example #23
    def __init__(self,
                 hparams,
                 device: Optional[torch.device] = None,
                 data_source: Optional[DataSource] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        self._other_transforms = self._hparams.dataset.other_transformations
        data_type = self._hparams.dataset["data_type"]
        if data_type not in get_supported_scalar_types():
            raise ValueError(f"Unsupported data type '{data_type}'")

        # In PyTorch versions < 1.1.0, "torch.uint8" is used in place of the
        # "bool" type, hence the numpy dtype is resolved through `torch_bool`.
        if data_type == "bool":
            self._data_type = get_numpy_dtype(str(torch_bool))
        else:
            self._data_type = get_numpy_dtype(data_type)

        if data_source is None:
            data_source = TextLineDataSource(
                self._hparams.dataset.files,
                compression_type=self._hparams.dataset.compression_type)
        super().__init__(data_source, hparams, device=device)
Code Example #24
File: layers.py Project: vllgle/texar-pytorch
def get_rnn_cell(input_size, hparams=None):
    r"""Creates an RNN cell.

    See :func:`~texar.torch.core.default_rnn_cell_hparams` for all
    hyperparameters and default values.

    Args:
        input_size (int): Size of the input to the cell in the first layer.
        hparams (dict or HParams, optional): Cell hyperparameters. Missing
            hyperparameters are set to default values.

    Returns:
        A cell instance.

    Raises:
        ValueError: If ``hparams["num_layers"] > 1`` and ``hparams["type"]``
            is a class instance.
    """
    if hparams is None or isinstance(hparams, dict):
        hparams = HParams(hparams, default_rnn_cell_hparams())

    d_hp = hparams['dropout']
    variational_recurrent = d_hp['variational_recurrent']
    input_keep_prob = d_hp['input_keep_prob']
    output_keep_prob = d_hp['output_keep_prob']
    state_keep_prob = d_hp['state_keep_prob']

    cells = []
    num_layers = hparams['num_layers']
    cell_kwargs = hparams['kwargs'].todict()
    # rename 'num_units' to 'hidden_size' following PyTorch conventions
    cell_kwargs['hidden_size'] = cell_kwargs['num_units']
    del cell_kwargs['num_units']

    for layer_i in range(num_layers):
        # Create the basic cell
        cell_type = hparams["type"]
        if layer_i == 0:
            cell_kwargs['input_size'] = input_size
        else:
            cell_kwargs['input_size'] = cell_kwargs['hidden_size']
        if not isinstance(cell_type, str) and not isinstance(cell_type, type):
            if num_layers > 1:
                raise ValueError(
                    "If 'num_layers'>1, then 'type' must be a cell class or "
                    "its name/module path, rather than a cell instance.")
        cell_modules = ['texar.torch.core.cell_wrappers',  # prefer our wrappers
                        'torch.nn.modules.rnn', 'texar.torch.custom']
        cell = utils.check_or_get_instance(cell_type, cell_kwargs, cell_modules)
        if isinstance(cell, nn.RNNCellBase):
            cell = wrappers.wrap_builtin_cell(cell)

        # Optionally add dropout
        if (input_keep_prob < 1.0 or
                output_keep_prob < 1.0 or
                state_keep_prob < 1.0):
            # TODO: Would this result in non-final layer outputs being
            #       dropped twice?
            cell = wrappers.DropoutWrapper(
                cell=cell,
                input_keep_prob=input_keep_prob,
                output_keep_prob=output_keep_prob,
                state_keep_prob=state_keep_prob,
                variational_recurrent=variational_recurrent)

        # Optionally add residual and highway connections
        if layer_i > 0:
            if hparams['residual']:
                cell = wrappers.ResidualWrapper(cell)
            if hparams['highway']:
                cell = wrappers.HighwayWrapper(cell)

        cells.append(cell)

    if hparams['num_layers'] > 1:
        cell = wrappers.MultiRNNCell(cells)
    else:
        cell = cells[0]

    return cell
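
A usage sketch of get_rnn_cell (assumed in scope with its texar imports). The hparams keys mirror the ones read above; the values are illustrative. This builds a two-layer LSTM stack with output dropout, wrapped in a MultiRNNCell.

cell = get_rnn_cell(
    input_size=64,
    hparams={
        "type": "LSTMCell",
        "num_layers": 2,
        "kwargs": {"num_units": 128},
        "dropout": {"output_keep_prob": 0.9},
    })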
Code Example #25
File: data_base.py Project: saradhix/texar-pytorch
    def __init__(self, source: DataSource[RawExample], hparams=None,
                 device: Optional[torch.device] = None):
        self._source = source
        self._hparams = HParams(hparams, self.default_hparams())
        self.device = device

        if self._hparams.num_epochs != 1:
            warnings.warn(f"'num_epochs' is set to {self._hparams.num_epochs}, "
                          f"but will be treated as 1.")

        # Check and convert strategy hyperparameters.
        self._lazy_strategy = _LazyStrategy(self._hparams.lazy_strategy)
        self._cache_strategy = _CacheStrategy(self._hparams.cache_strategy)
        if self._lazy_strategy is _LazyStrategy.NONE:
            if self._cache_strategy is not _CacheStrategy.PROCESSED:
                warnings.warn(
                    f"Using '{self._cache_strategy}' cache strategy with "
                    f"'none' lazy strategy. This will be equivalent to "
                    f"'processed' cache strategy.")
            self._cache_strategy = _CacheStrategy.PROCESSED
        elif self._lazy_strategy is _LazyStrategy.PROCESS:
            if self._cache_strategy is _CacheStrategy.NONE:
                warnings.warn(
                    f"Using 'none' cache strategy with 'process' lazy "
                    f"strategy. This will be equivalent to 'loaded' cache "
                    f"strategy.")
                self._cache_strategy = _CacheStrategy.LOADED
        self._uses_multi_processing = self._hparams.num_parallel_calls > 0
        self._parallelize_processing = self._hparams.parallelize_processing

        self._processed_cache: List[Example] = []
        self._fully_cached = False

        # If a maximum dataset size is specified, wrap the data source. This is
        # done before caching to avoid caching excess elements.
        if self._hparams.max_dataset_size != -1:
            self._source = _TruncatedDataSource[RawExample](
                self._source, self._hparams.max_dataset_size)

        # If processing should not be parallelized, combine processing with
        # loading by wrapping the data source. In this case, **processed** data
        # will be cached.
        if (not self._parallelize_processing and
                self._lazy_strategy is _LazyStrategy.ALL and
                self._cache_strategy is not _CacheStrategy.LOADED):
            self._transformed_source = _TransformedDataSource[
                RawExample, Example](self._source, self.process)
            self._source = self._transformed_source  # type: ignore

        # Check whether data source supports random access, and obtain dataset
        # size if it does.
        self._supports_random_access = True
        if self._lazy_strategy is not _LazyStrategy.NONE:
            try:
                self._dataset_size = len(self._source)
                _ = self._source[0]
            except TypeError:
                self._supports_random_access = False
                erase_after_access = (
                        self._cache_strategy is not _CacheStrategy.LOADED)
                self._cached_source = _CachedDataSource[RawExample](
                    self._source, erase_after_access)
                self._source = self._cached_source
                self._dataset_size = None

        # If processing should not be parallelized, combine processing with
        # loading by wrapping the data source. In this case, **loaded** data
        # will be cached.
        if (not self._parallelize_processing and
                self._cache_strategy is _CacheStrategy.LOADED):
            self._transformed_source = _TransformedDataSource[
                RawExample, Example](self._source, self.process)
            self._source = self._transformed_source  # type: ignore

        # Simplify some logic-heavy checks.
        self.__should_return_processed_examples = (
                self._lazy_strategy is not _LazyStrategy.NONE and
                self._cache_strategy is _CacheStrategy.PROCESSED and
                self._parallelize_processing)
        self.__should_call_prefetch_source = (
                self._lazy_strategy is _LazyStrategy.ALL and
                self._cache_strategy is _CacheStrategy.NONE)
        self.__should_call_prefetch_processed = (
                not self._parallelize_processing and
                self._lazy_strategy is _LazyStrategy.PROCESS and
                self._cache_strategy is _CacheStrategy.PROCESSED)
        self.__should_delete_source_in_add_cache = (
                not self._supports_random_access and
                self._parallelize_processing and
                self._uses_multi_processing and
                self._lazy_strategy is _LazyStrategy.PROCESS and
                self._cache_strategy is _CacheStrategy.PROCESSED)

        # Perform eager loading/processing if required.
        if self._lazy_strategy is _LazyStrategy.NONE:
            # Process entire dataset and cache.
            self._processed_cache = [self.process(raw_example)
                                     for raw_example in self._source]
            self._dataset_size = len(self._processed_cache)
            self._fully_cached = True
        else:
            if self._lazy_strategy is _LazyStrategy.PROCESS:
                # Load entire dataset. Note that if data source supports random
                # access, we assume it is already loaded into memory.
                if not self._supports_random_access:
                    self._prefetch_all_source()

            if self._cache_strategy is _CacheStrategy.PROCESSED:
                # Data can be processed in arbitrary order, so they need to be
                # reordered before storing in the cache list.
                self._reorder_cache: Dict[int, Example] = {}
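
A hedged sketch of the strategy-related hyperparameters the constructor above reads. The key names mirror the attribute accesses in the code; the lowercase string values are an assumption about the exact spellings accepted by the _LazyStrategy and _CacheStrategy enums.

hparams = {
    "lazy_strategy": "process",     # "none" | "process" | "all" (assumed spellings)
    "cache_strategy": "loaded",     # "none" | "loaded" | "processed" (assumed spellings)
    "num_parallel_calls": 1,        # > 0 enables multi-processing
    "parallelize_processing": True,
    "max_dataset_size": -1,         # -1: no truncation
    "num_epochs": 1,
}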
Code Example #26
def main(nif_context: str, nif_page_structure: str, mapping_literals: str,
         mapping_objects: str, nif_text_links: str, redirects: str,
         info_boxs: str, output_path: str):
    # Load redirects.
    logging.info("Loading redirects")
    redirect_pickle = os.path.join(output_path, 'redirects.pickle')
    if os.path.exists(redirect_pickle):
        redirect_map: Dict[str, str] = pickle.load(open(redirect_pickle, 'rb'))
    else:
        redirect_map: Dict[str, str] = load_redirects(redirects)
        with open(redirect_pickle, 'wb') as pickle_f:
            pickle.dump(redirect_map, pickle_f)
    logging.info("Done loading.")

    # The datasets are read in two steps.
    raw_pack_dir = os.path.join(output_path, 'nif_raw')

    # First, we create the NIF reader that reads the NIF data in order.
    nif_pl = Pipeline()
    nif_pl.resource.update(redirects=redirect_map)

    nif_pl.set_reader(DBpediaWikiReader(),
                      config=HParams(
                          {
                              'redirect_path': redirects,
                              'nif_page_structure': nif_page_structure,
                              'nif_text_links': nif_text_links,
                          }, DBpediaWikiReader.default_configs()))

    nif_pl.add_processor(WikiArticleWriter(),
                         config=HParams(
                             {
                                 'output_dir': raw_pack_dir,
                                 'zip_pack': True,
                             }, WikiArticleWriter.default_configs()))

    nif_pl.initialize()
    logging.info('Start running the DBpedia text pipeline.')
    nif_pl.run(nif_context)

    # Second, we add info boxes to the packs with NIF.
    ib_pl = Pipeline()
    ib_pl.resource.update(redirects=redirect_map)
    ib_pl.set_reader(
        DBpediaInfoBoxReader(),
        config=HParams(
            {
                'pack_index': os.path.join(raw_pack_dir, 'article.idx'),
                'pack_dir': raw_pack_dir,
                'mapping_literals': mapping_literals,
                'mapping_objects': mapping_objects,
                'reading_log': os.path.join(output_path, 'infobox.log')
            }, DBpediaInfoBoxReader.default_configs()))

    ib_pl.add_processor(
        WikiArticleWriter(),
        config=HParams(
            {
                'output_dir': os.path.join(output_path, 'nif_info_box'),
                'zip_pack': True,
            }, WikiArticleWriter.default_configs()))

    # Now we run the info box pipeline.
    ib_pl.initialize()
    ib_pl.run(info_boxs)
Code Example #27
    def test_hparams(self):
        r"""Tests the HParams class.
        """
        default_hparams = {
            "str": "str",
            "list": ['item1', 'item2'],
            "dict": {
                "key1": "value1",
                "key2": "value2"
            },
            "nested_dict": {
                "dict_l2": {
                    "key1_l2": "value1_l2"
                }
            },
            "type": "type",
            "kwargs": {
                "arg1": "argv1"
            },
        }

        # Test HParams.items() function
        hparams_ = HParams(None, default_hparams)
        names = []
        for name, _ in hparams_.items():
            names.append(name)
        self.assertEqual(set(names), set(default_hparams.keys()))

        hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}

        hparams_ = HParams(hparams, default_hparams)

        # Test HParams construction
        self.assertEqual(hparams_.str, default_hparams["str"])
        self.assertEqual(hparams_.list, default_hparams["list"])
        self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
        self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
        self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                         default_hparams["nested_dict"]["dict_l2"]["key1_l2"])

        self.assertEqual(len(hparams_), len(default_hparams))

        new_hparams = copy.deepcopy(default_hparams)
        new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
        new_hparams["kwargs"].update(hparams["kwargs"])
        self.assertEqual(hparams_.todict(), new_hparams)

        self.assertTrue("dict" in hparams_)

        self.assertIsNone(hparams_.get('not_existed_name', None))
        self.assertEqual(hparams_.get('str'), default_hparams['str'])

        # Test HParams update related operations
        hparams_.str = "new_str"
        hparams_.dict = {"key3": "value3"}
        self.assertEqual(hparams_.str, "new_str")
        self.assertEqual(hparams_.dict.key3, "value3")

        hparams_.add_hparam("added_str", "added_str")
        hparams_.add_hparam("added_dict", {"key4": "value4"})
        hparams_.kwargs.add_hparam("added_arg", "added_argv")
        self.assertEqual(hparams_.added_str, "added_str")
        self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
        self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

        # Test HParams I/O
        hparams_file = tempfile.NamedTemporaryFile()
        pickle.dump(hparams_, hparams_file)
        with open(hparams_file.name, 'rb') as hparams_file:
            hparams_loaded = pickle.load(hparams_file)
        self.assertEqual(hparams_loaded.todict(), hparams_.todict())
Code Example #28
    def __init__(self, hparams, device: Optional[torch.device] = None):
        self._hparams = HParams(hparams, self.default_hparams())
        # Fill in default hyperparameters for each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for hparams_i in datasets_hparams:
            data_type = hparams_i.get("data_type", None)
            defaultized_ds_hpms = HParams(hparams_i,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        dummy_source = SequenceDataSource[Any]([])
        name_prefix: List[str] = []
        self._names: List[Dict[str, Any]] = []
        sources: List[DataSource] = []
        filters: List[Optional[Callable[[str], bool]]] = []
        self._databases: List[DataBase] = []
        for idx, hparams_i in enumerate(self._hparams.datasets):
            data_type = _DataType(hparams_i.data_type)
            source_i: DataSource

            if _is_text_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type,
                    delimiter=hparams_i.delimiter)
                sources.append(source_i)
                if ((hparams_i.length_filter_mode
                     == _LengthFilterMode.DISCARD.value)
                        and hparams_i.max_seq_length is not None):

                    def _get_filter(max_seq_length):
                        return lambda x: len(x) <= max_seq_length

                    filters.append(_get_filter(hparams_i.max_seq_length))
                else:
                    filters.append(None)

                self._names.append({
                    field: connect_name(hparams_i.data_name, field)
                    for field in ["text", "text_ids", "length"]
                })

                dataset_hparams = dict_fetch(
                    hparams_i,
                    MonoTextData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = None
                self._databases.append(
                    MonoTextData(hparams={"dataset": dataset_hparams},
                                 device=device,
                                 vocab=self._vocab[idx],
                                 embedding=self._embedding[idx],
                                 data_source=dummy_source))
            elif _is_scalar_data(data_type):
                source_i = TextLineDataSource(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                sources.append(source_i)
                filters.append(None)
                self._names.append({"data": hparams_i.data_name})

                dataset_hparams = dict_fetch(
                    hparams_i,
                    ScalarData.default_hparams()["dataset"])
                dataset_hparams["data_name"] = "data"
                self._databases.append(
                    ScalarData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            elif _is_record_data(data_type):
                source_i = PickleDataSource(file_paths=hparams_i.files)
                sources.append(source_i)
                self._names.append({
                    name: connect_name(hparams_i.data_name, name)
                    for name in hparams_i.feature_original_types.keys()
                })
                filters.append(None)

                dataset_hparams = dict_fetch(
                    hparams_i,
                    RecordData.default_hparams()["dataset"])
                self._databases.append(
                    RecordData(hparams={"dataset": dataset_hparams},
                               device=device,
                               data_source=dummy_source))
            else:
                raise ValueError(f"Unknown data type: {hparams_i.data_type}")

            # check for duplicate names
            for i in range(1, len(name_prefix)):
                if name_prefix[i] in name_prefix[:i]:
                    raise ValueError(f"Duplicate data name: {name_prefix[i]}")

            name_prefix.append(hparams_i["data_name"])

        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        data_source: DataSource = ZipDataSource(*sources)

        if any(filters):

            def filter_fn(data):
                return all(
                    fn(data) for fn, data in zip(filters, data)
                    if fn is not None)

            data_source = FilterDataSource(data_source, filter_fn=filter_fn)
        super().__init__(data_source, self._hparams, device)
Code Example #29
File: layers.py Project: vllgle/texar-pytorch
def get_layer(hparams: Union[HParams, Dict[str, Any]]) -> nn.Module:
    r"""Makes a layer instance.

    The layer must be an instance of :torch_nn:`Module`.

    Args:
        hparams (dict or HParams): Hyperparameters of the layer, with
            structure:

            .. code-block:: python

                {
                    "type": "LayerClass",
                    "kwargs": {
                        # Keyword arguments of the layer class
                        # ...
                    }
                }

            Here:

            `"type"`: str or layer class or layer instance
                The layer type. This can be

                - The string name or full module path of a layer class. If
                  the class name is provided, the class must be in module
                  :torch_nn:`Module`, :mod:`texar.torch.core`, or
                  :mod:`texar.torch.custom`.
                - A layer class.
                - An instance of a layer class.

                For example

                .. code-block:: python

                    "type": "Conv1D"                               # class name
                    "type": "texar.torch.core.MaxReducePooling1D"  # module path
                    "type": "my_module.MyLayer"                    # module path
                    "type": torch.nn.Module.Linear                 # class
                    "type": Conv1D(filters=10, kernel_size=2)  # cell instance
                    "type": MyLayer(...)                       # cell instance

            `"kwargs"`: dict
                A dictionary of keyword arguments for constructor of the
                layer class. Ignored if :attr:`"type"` is a layer instance.

                - Arguments named "activation" can be a callable, or a `str` of
                  the name or module path to the activation function.
                - Arguments named "\*_regularizer" and "\*_initializer" can be a
                  class instance, or a `dict` of hyperparameters of the
                  respective regularizers and initializers.
                - Arguments named "\*_constraint" can be a callable, or a `str`
                  of the name or full path to the constraint function.

    Returns:
        A layer instance. If ``hparams["type"]`` is a layer instance, returns it
        directly.

    Raises:
        ValueError: If :attr:`hparams` is `None`.
        ValueError: If the resulting layer is not an instance of
            :torch_nn:`Module`.
    """
    if hparams is None:
        raise ValueError("`hparams` must not be `None`.")

    layer_type = hparams["type"]
    if not is_str(layer_type) and not isinstance(layer_type, type):
        layer = layer_type
    else:
        layer_modules = ["torch.nn", "texar.torch.core", "texar.torch.custom"]
        layer_class: Type[nn.Module] = utils.check_or_get_class(
            layer_type, layer_modules)
        if isinstance(hparams, dict):
            if (layer_class.__name__ == "Linear" and
                    "in_features" not in hparams["kwargs"]):
                raise ValueError("\"in_features\" should be specified for "
                                 "\"torch.nn.{}\"".format(layer_class.__name__))
            elif (layer_class.__name__ in ["Conv1d", "Conv2d", "Conv3d"] and
                  "in_channels" not in hparams["kwargs"]):
                raise ValueError("\"in_channels\" should be specified for "
                                 "\"torch.nn.{}\"".format(layer_class.__name__))
            default_kwargs = _layer_class_to_default_kwargs_map.get(
                layer_class, {})
            default_hparams = {"type": layer_type, "kwargs": default_kwargs}
            hparams = HParams(hparams, default_hparams)

        # This case needs to be handled separately because
        # :torch_nn:`Sequential` does not accept kwargs.
        if layer_type == "Sequential":
            names: List[str] = []
            layer = nn.Sequential()
            sub_hparams = hparams.kwargs.layers
            for hparam in sub_hparams:
                sub_layer = get_layer(hparam)
                name = utils.uniquify_str(sub_layer._get_name(), names)
                names.append(name)
                layer.add_module(name=name, module=sub_layer)
        else:
            layer = utils.get_instance(layer_type, hparams.kwargs.todict(),
                                       layer_modules)

    if not isinstance(layer, nn.Module):
        raise ValueError("layer must be an instance of `torch.nn.Module`.")

    return layer
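A minimal usage sketch of ``get_layer``, assuming it is importable as ``texar.torch.core.get_layer``; the layer choices and dimensions below are illustrative assumptions, not taken from the project:

import torch
from texar.torch.core import get_layer

# Build a torch.nn.Linear from its class name; "in_features" must be
# given explicitly, since get_layer cannot infer it.
linear = get_layer({
    "type": "Linear",
    "kwargs": {"in_features": 64, "out_features": 32},
})

# Build an nn.Sequential; each entry of "layers" is itself a layer
# hyperparameter dict and is constructed recursively via get_layer.
mlp = get_layer({
    "type": "Sequential",
    "kwargs": {
        "layers": [
            {"type": "Linear",
             "kwargs": {"in_features": 64, "out_features": 32}},
            {"type": "ReLU", "kwargs": {}},
        ],
    },
})

x = torch.randn(8, 64)
assert linear(x).shape == (8, 32)
assert mlp(x).shape == (8, 32)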
Code example #30
class PretrainedMixin(ModuleBase, ABC):
    r"""A mixin class for all pre-trained classes to inherit.
    """

    _MODEL_NAME: str
    _MODEL2URL: Dict[str, MaybeList[str]]

    pretrained_model_dir: Optional[str]

    @classmethod
    def available_checkpoints(cls) -> List[str]:
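        r"""Returns the names of all pre-trained checkpoints available for
        this model class.
        """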
        return list(cls._MODEL2URL.keys())

    def _name_to_variable(self, name: str) -> nn.Parameter:
        r"""Find the corresponding variable given the specified name.
        """
        pointer = self
        for m_name in name.split("."):
            if m_name.isdigit():
                num = int(m_name)
                pointer = pointer[num]  # type: ignore
            else:
                pointer = getattr(pointer, m_name)
        return pointer  # type: ignore

    def load_pretrained_config(self,
                               pretrained_model_name: Optional[str] = None,
                               cache_dir: Optional[str] = None,
                               hparams=None):
        r"""Load paths and configurations of the pre-trained model.

        Args:
            pretrained_model_name (optional): A str with the name
                of a pre-trained model to load. If `None`, will use the model
                name in :attr:`hparams`.
            cache_dir (optional): The path to a folder in which the
                pre-trained models will be cached. If `None` (default),
                a default directory will be used.
            hparams (dict or HParams, optional): Hyperparameters. Missing
                hyperparameter will be set to default values. See
                :meth:`default_hparams` for the hyperparameter structure
                and default values.
        """
        if not hasattr(self, "_hparams"):
            self._hparams = HParams(hparams, self.default_hparams())
        else:
            # Probably already parsed by subclasses. We rely on subclass
            # implementations to get this right.
            # As a sanity check, we require `hparams` to be `None` in this case.
            if hparams is not None:
                raise ValueError(
                    "`self._hparams` is already assigned, but `hparams` "
                    "argument is not None.")

        self.pretrained_model_dir = None
        self.pretrained_model_name = pretrained_model_name

        if self.pretrained_model_name is None:
            self.pretrained_model_name = self._hparams.pretrained_model_name
        if self.pretrained_model_name is not None:
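            # Download (or locate in the cache) the checkpoint, then let the
            # checkpoint's own configuration override the current hparams.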
            self.pretrained_model_dir = self.download_checkpoint(
                self.pretrained_model_name, cache_dir)
            pretrained_model_hparams = self._transform_config(
                self.pretrained_model_name, self.pretrained_model_dir)
            self._hparams = HParams(pretrained_model_hparams,
                                    self._hparams.todict())

    def init_pretrained_weights(self, *args, **kwargs):
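        r"""Initialize module parameters from the pre-trained checkpoint if
        one was loaded; otherwise fall back to :meth:`reset_parameters`.
        """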
        if self.pretrained_model_dir:
            self._init_from_checkpoint(self.pretrained_model_name,
                                       self.pretrained_model_dir, *args,
                                       **kwargs)
        else:
            self.reset_parameters()

    def reset_parameters(self):
        r"""Initialize parameters of the pre-trained model. This method is only
        called if pre-trained checkpoints are not loaded.
        """
        pass

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "pretrained_model_name": None,
                "name": "pretrained_base"
            }
        """
        return {
            'pretrained_model_name': None,
            'name': "pretrained_base",
            '@no_typecheck': ['pretrained_model_name']
        }

    @classmethod
    def download_checkpoint(cls,
                            pretrained_model_name: str,
                            cache_dir: Optional[str] = None) -> str:
        r"""Download the specified pre-trained checkpoint, and return the
        directory in which the checkpoint is cached.

        Args:
            pretrained_model_name (str): Name of the model checkpoint.
            cache_dir (str, optional): Path to the cache directory. If `None`,
                uses the default directory (user's home directory).

        Returns:
            Path to the directory in which the checkpoint is cached, i.e.,
            the resolved cache directory joined with ``pretrained_model_name``.
        """
        if pretrained_model_name in cls._MODEL2URL:
            download_path = cls._MODEL2URL[pretrained_model_name]
        else:
            raise ValueError(
                f"Pre-trained model not found: {pretrained_model_name}")

        if cache_dir is None:
            cache_path = default_download_dir(cls._MODEL_NAME)
        else:
            cache_path = Path(cache_dir)
        cache_path = cache_path / pretrained_model_name

        if not cache_path.exists():
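            # A single URL points to an archive: download and extract it, then
            # flatten the extracted folder so its files sit directly under
            # cache_path.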
            if isinstance(download_path, str):
                filename = download_path.split('/')[-1]
                maybe_download(download_path, cache_path, extract=True)
                folder = None
                for file in cache_path.iterdir():
                    if file.is_dir():
                        folder = file
                assert folder is not None
                (cache_path / filename).unlink()
                for file in folder.iterdir():
                    file.rename(file.parents[1] / file.name)
                folder.rmdir()
            else:
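                # A list of URLs: download each file directly into cache_path.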
                for path in download_path:
                    maybe_download(path, cache_path)
            print(f"Pre-trained {cls._MODEL_NAME} checkpoint "
                  f"{pretrained_model_name} cached to {cache_path}")
        else:
            print(f"Using cached pre-trained {cls._MODEL_NAME} checkpoint "
                  f"from {cache_path}.")

        return str(cache_path)

    @classmethod
    @abstractmethod
    def _transform_config(cls, pretrained_model_name: str,
                          cache_dir: str) -> Dict[str, Any]:
        r"""Load the official configuration file and transform it into
        Texar-style hyperparameters.

        Args:
            pretrained_model_name (str): Name of the pre-trained model.
            cache_dir (str): Path to the cache directory.

        Returns:
            dict: Texar module hyperparameters.
        """
        raise NotImplementedError

    @abstractmethod
    def _init_from_checkpoint(self, pretrained_model_name: str, cache_dir: str,
                              **kwargs):
        r"""Initialize model parameters from weights stored in the pre-trained
        checkpoint.

        Args:
            pretrained_model_name (str): Name of the pre-trained model.
            cache_dir (str): Path to the cache directory.
            **kwargs: Additional arguments for specific models.
        """
        raise NotImplementedError
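
To make the contract of the two abstract hooks concrete, here is a minimal, hypothetical subclass sketch. The model name, URL, config keys, and file names are invented for illustration and do not correspond to any real checkpoint; ``PretrainedMixin`` is assumed to be the class shown above.

import json
import os
from typing import Any, Dict

import torch
from torch import nn


class ToyPretrainedEncoder(PretrainedMixin):
    r"""Hypothetical pre-trained encoder illustrating the mixin contract."""

    _MODEL_NAME = "toy-encoder"
    _MODEL2URL = {
        # Placeholder URL; a real subclass would point at actual checkpoints.
        "toy-encoder-base": "https://example.com/toy-encoder-base.zip",
    }

    def __init__(self, pretrained_model_name=None, cache_dir=None,
                 hparams=None):
        super().__init__(hparams=hparams)
        # Resolve the checkpoint path and merge the checkpoint's own
        # configuration into self._hparams before building sub-modules.
        self.load_pretrained_config(pretrained_model_name, cache_dir)
        self.linear = nn.Linear(self._hparams.hidden_dim,
                                self._hparams.hidden_dim)
        # Load pre-trained weights, or fall back to reset_parameters().
        self.init_pretrained_weights()

    @staticmethod
    def default_hparams():
        return {
            "pretrained_model_name": None,
            "name": "toy_pretrained_encoder",
            "hidden_dim": 128,
            "@no_typecheck": ["pretrained_model_name"],
        }

    @classmethod
    def _transform_config(cls, pretrained_model_name: str,
                          cache_dir: str) -> Dict[str, Any]:
        # Map the checkpoint's native config file onto Texar-style hparams.
        with open(os.path.join(cache_dir, "config.json")) as f:
            config = json.load(f)
        return {"hidden_dim": config["hidden_size"]}

    def _init_from_checkpoint(self, pretrained_model_name: str,
                              cache_dir: str, **kwargs):
        # Copy weights stored in the checkpoint into this module.
        state_dict = torch.load(os.path.join(cache_dir, "model.pt"))
        self.linear.load_state_dict(state_dict)

With such a subclass, ``ToyPretrainedEncoder(pretrained_model_name="toy-encoder-base")`` would download, cache, and load the (fictional) checkpoint, while ``ToyPretrainedEncoder()`` builds a randomly initialized model via ``reset_parameters``.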