Beispiel #1
0
    def __init__(self, config):
        """
        Read the model name and that model's options from ``config``.

        * Args:
            config: model config. ``config.name`` names the model; the attribute
                of the same name, if present and truthy, holds the model's
                options as an object accepted by ``vars()`` (e.g. a namespace).
        """
        self.registry = Registry()

        self.name = config.name
        self.model_config = {}
        model_options = getattr(config, config.name, None)
        if model_options:
            # vars() returns the object's own __dict__; copy it so later writes
            # into self.model_config do not leak back into ``config``.
            self.model_config = dict(vars(model_options))

        self.is_independent = getattr(config, "independent", False)
Beispiel #2
0
def test_open_qa_with_bidaf_model(open_qa_config):
    """
    Smoke test: build the configured machine and answer a random question.

    Nothing is asserted. The test passes as long as construction, answering
    and JSON-serialising the answer raise no exception.
    """
    machine_name = open_qa_config.name
    machine_config = getattr(open_qa_config, machine_name, {})

    machine_cls = Registry().get(f"machine:{machine_name}")
    machine = machine_cls(machine_config)

    random_question = utils.make_random_tokens(5)
    # json.dumps raises TypeError if the answer is not JSON-serialisable.
    json.dumps(machine(random_question), indent=4, ensure_ascii=False)
Beispiel #3
0
    def __init__(
        self,
        file_paths,
        tokenizers,
        batch_sizes=None,
        readers=None,
    ):
        """
        Build one data reader per entry of ``readers`` and read its data.

        * Args:
            file_paths: forwarded to the base reader's constructor.
            tokenizers: stored on ``self.tokenizers``.

        * Kwargs:
            batch_sizes: one batch size per entry of ``readers`` (default: empty).
            readers: reader configs; each must have a ``"dataset"`` key
                (default: empty).

        * Raises:
            ValueError: if ``batch_sizes`` and ``readers`` differ in length.
        """
        # None sentinels instead of ``[]`` defaults: a default list is created
        # once and shared by every call, and ``batch_sizes`` is kept on the
        # instance, so mutating it would leak across readers.
        batch_sizes = [] if batch_sizes is None else batch_sizes
        readers = [] if readers is None else readers

        super().__init__(file_paths, MultiTaskBertDataset)
        # Explicit check (not ``assert``) so validation survives ``python -O``.
        if len(batch_sizes) != len(readers):
            raise ValueError("batch_sizes and readers must be same length.")

        self.registry = Registry()
        self.text_columns = ["bert_input"]
        # NOTE(review): DataReaderFactory as defined elsewhere in this file
        # requires a ``config`` argument; confirm a no-argument constructor
        # exists in the version this reader is used with.
        self.data_reader_factory = DataReaderFactory()

        self.tokenizers = tokenizers
        self.batch_sizes = batch_sizes

        # Parallel lists, one entry per reader, in ``readers`` order.
        self.dataset_batches = []
        self.dataset_helpers = []
        self.tasks = []

        for reader in readers:
            data_reader = self.make_data_reader(reader)
            batches, helpers = data_reader.read()

            self.dataset_batches.append(batches)
            self.dataset_helpers.append(helpers)

            # The task is built from the "train" helper of this reader.
            dataset_name = reader["dataset"]
            helper = helpers["train"]
            task = self.make_task_by_reader(dataset_name, data_reader, helper)
            self.tasks.append(task)
Beispiel #4
0
class Machine:
    """
    Machine: Combine modules then make a NLP Machine

    * Args:
        config: machine_config
    """
    def __init__(self, config):
        self.config = config
        self.registry = Registry()

    def load(self):
        """Load the machine; not implemented in the base class."""
        raise NotImplementedError("")

    @classmethod
    def load_from_config(cls, config_path):
        """
        Build a machine from a UTF-8 JSON config file.

        * Args:
            config_path: path to the JSON config. The loaded config's ``name``
                selects the sub-config (default ``{}``) passed to ``cls``.
        """
        with open(config_path, "r", encoding="utf-8") as in_file:
            machine_config = NestedNamespace()
            machine_config.load_from_json(json.load(in_file))

        machine_name = machine_config.name
        config = getattr(machine_config, machine_name, {})
        return cls(config)

    def __call__(self, text):
        """Process ``text``; not implemented in the base class."""
        raise NotImplementedError("")

    def make_module(self, config):
        """
        Make component or experiment for claf Machine's module

        * Args:
            - config: module's config (claf.config.namespace.NestedNamespace)

        * Returns:
            - component: the registered ``component:{config.name}`` built from
              its sub-config merged with ``config.params``
            - experiment: an Experiment in PREDICT mode, preloaded

        * Raises:
            - ValueError: if ``config.type`` is neither component nor experiment
        """

        module_type = config.type
        if module_type == Module.COMPONENT:
            name = config.name
            module_config = getattr(config, name, {})
            if isinstance(module_config, Namespace):
                module_config = vars(module_config)
            # Copy before merging ``params``: vars() returns the namespace's own
            # __dict__, so an in-place update would write ``params`` back into
            # ``config`` and affect any later use of the same config.
            module_config = dict(module_config)

            if getattr(config, "params", None):
                module_config.update(config.params)
            return self.registry.get(f"component:{name}")(**module_config)
        elif module_type == Module.EXPERIMENT:
            experiment_config = Namespace()
            experiment_config.checkpoint_path = config.checkpoint_path
            experiment_config.cuda_devices = getattr(config, "cuda_devices",
                                                     None)
            experiment_config.interactive = False

            experiment = Experiment(Mode.PREDICT, experiment_config)
            experiment.set_predict_mode(preload=True)
            return experiment
        else:
            raise ValueError(
                f"module_type is available only [component|experiment]. not '{module_type}'"
            )
Beispiel #5
0
    def __init__(self, config):
        """
        Collect the keyword arguments for the configured dataset's reader.

        * Args:
            config: data_reader config. Reads ``dataset`` plus the optional
                ``train_file_path``, ``valid_file_path``, ``params``,
                ``tokenizers`` and a per-dataset sub-config named after
                ``config.dataset``.
        """
        self.registry = Registry()

        self.dataset = config.dataset

        # Missing or empty-string paths are skipped (both are falsy).
        file_paths = {}
        for split in ("train", "valid"):
            path = getattr(config, f"{split}_file_path", None)
            if path:
                file_paths[split] = path

        self.reader_config = {"file_paths": file_paths}
        # isinstance (not ``type(...) == dict``) so dict subclasses such as
        # OrderedDict are merged too instead of being silently ignored.
        if "params" in config and isinstance(config.params, dict):
            self.reader_config.update(config.params)
        if "tokenizers" in config:
            self.reader_config["tokenizers"] = config.tokenizers

        # Per-dataset options live under an attribute named after the dataset.
        dataset_config = getattr(config, config.dataset, None)
        if dataset_config is not None:
            self.reader_config.update(vars(dataset_config))
Beispiel #6
0
class ModelFactory(Factory):
    """
    Model Factory Class

    Create Concrete model according to config.name
    Get model from model registries (eg. @register("model:{model_name}"))

    * Args:
        config: model config from argument (config.model)
    """
    def __init__(self, config):
        self.registry = Registry()

        self.name = config.name
        self.model_config = {}
        model_options = getattr(config, config.name, None)
        if model_options:
            # vars() returns the object's own __dict__; copy it so that create()
            # adding token_embedder / token_makers does not write those objects
            # back into the caller's config.
            self.model_config = dict(vars(model_options))

        self.is_independent = getattr(config, "independent", False)

    @overrides
    def create(self, token_makers, **params):
        """
        Instantiate the model registered as ``model:{self.name}``.

        * Args:
            token_makers: passed to the model directly, or wrapped in a token
                embedder, depending on the model's base class.
            **params: extra keyword arguments forwarded to the model.

        * Raises:
            ValueError: if the model subclasses neither ModelWithTokenEmbedder
                nor ModelWithoutTokenEmbedder.
        """
        model = self.registry.get(f"model:{self.name}")

        if issubclass(model, ModelWithTokenEmbedder):
            # Named ``embedder`` so it does not shadow the module-level
            # ``token_embedder`` name used by create_token_embedder().
            embedder = self.create_token_embedder(model, token_makers)
            self.model_config["token_embedder"] = embedder
        elif issubclass(model, ModelWithoutTokenEmbedder):
            self.model_config["token_makers"] = token_makers
        else:
            raise ValueError(
                "Model must have inheritance. (ModelWithTokenEmbedder or ModelWithoutTokenEmbedder)"
            )

        return model(**self.model_config, **params)

    def create_token_embedder(self, model, token_makers):
        """
        Pick a token embedder for ``model``: RCTokenEmbedder for
        ReadingComprehension models, BasicTokenEmbedder otherwise.
        """
        # 1. Specific case
        # ...

        # 2. Base case
        if issubclass(model, ReadingComprehension):
            return token_embedder.RCTokenEmbedder(token_makers)
        else:
            return token_embedder.BasicTokenEmbedder(token_makers)
Beispiel #7
0
class TokenMakersFactory(Factory):
    """
    TokenMakers Factory Class

    Builds the shared tokenizers and one token maker per configured token.

    * Args:
        config: token config from argument (config.token)
    """

    LANGS = ["eng", "kor"]

    def __init__(self, config):
        self.config = config
        self.registry = Registry()

    @overrides
    def create(self):
        """
        Create the tokenizers and every configured token maker.

        * Returns:
            dict with key "tokenizers" (the tokenizers built from
            ``config.tokenizer``, or ``{}``) plus one entry per token name,
            mapped to the maker registered as ``token:{type}``.

        * Raises:
            ValueError: if ``config.names`` and ``config.types`` differ in length.
        """
        tokenizer_config = getattr(self.config, "tokenizer", None)
        tokenizers = (
            make_all_tokenizers(convert_config2dict(tokenizer_config))
            if tokenizer_config
            else {}
        )

        names = self.config.names
        kinds = self.config.types
        if len(names) != len(kinds):
            raise ValueError(
                "token_names and token_types must be same length.")

        token_makers = {"tokenizers": tokenizers}
        # Sorting the (name, type) pairs gives a deterministic build order.
        for name, kind in sorted(zip(names, kinds)):
            raw = getattr(self.config, name, {})
            sub_configs = convert_config2dict(raw) if raw != {} else raw

            # Token (tokenizer, indexer, embedding, vocab)
            maker_kwargs = {
                "tokenizers": tokenizers,
                "indexer_config": sub_configs.get("indexer", {}),
                "embedding_config": sub_configs.get("embedding", {}),
                "vocab_config": sub_configs.get("vocab", {}),
            }
            maker_cls = self.registry.get(f"token:{kind}")
            token_makers[name] = maker_cls(**maker_kwargs)
        return token_makers
Beispiel #8
0
class DataReaderFactory(Factory):
    """
    DataReader Factory Class

    Create Concrete reader according to config.dataset
    Get reader from reader registries (eg. @register("reader:{reader_name}"))

    * Args:
        config: data_reader config from argument (config.data_reader)
    """
    def __init__(self, config):
        self.registry = Registry()

        self.dataset_name = config.dataset

        # Missing or empty-string paths are skipped (both are falsy).
        file_paths = {}
        for split in ("train", "valid"):
            path = getattr(config, f"{split}_file_path", None)
            if path:
                file_paths[split] = path

        self.reader_config = {"file_paths": file_paths}
        # isinstance (not ``type(...) == dict``) so dict subclasses such as
        # OrderedDict are merged too instead of being silently ignored.
        if "params" in config and isinstance(config.params, dict):
            self.reader_config.update(config.params)
        if "tokenizers" in config:
            self.reader_config["tokenizers"] = config.tokenizers

        # Per-dataset options live under an attribute named after the dataset.
        dataset_config = getattr(config, config.dataset, None)
        if dataset_config is not None:
            self.reader_config.update(vars(dataset_config))

    @overrides
    def create(self):
        """Instantiate the reader registered as ``reader:{dataset_name.lower()}``."""
        reader = self.registry.get(f"reader:{self.dataset_name.lower()}")
        return reader(**self.reader_config)
Beispiel #9
0
# -*- coding: utf-8 -*-

import json

from claf.config import args
from claf.config.registry import Registry
from claf.learn.mode import Mode
from claf import utils as common_utils

if __name__ == "__main__":
    # Interactive loop: build the machine named by the machine config, then
    # answer user input forever (the loop has no exit condition of its own).
    registry = Registry()

    machine_config = args.config(mode=Mode.MACHINE)
    machine_name = machine_config.name
    config = getattr(machine_config, machine_name, {})

    machine_cls = registry.get(f"machine:{machine_name}")
    machine = machine_cls(config)

    user_label = getattr(machine_config, "user_input", "Question")
    system_label = getattr(machine_config, "system_response", "Answer")

    while True:
        question = common_utils.get_user_input(f"{user_label}")
        # NOTE(review): the Machine base class in this file defines __call__,
        # not get_answer; confirm the concrete machine provides get_answer.
        reply = json.dumps(
            machine.get_answer(question), indent=4, ensure_ascii=False)
        print(f"{system_label}: {reply}")
Beispiel #10
0
 def __init__(self):
     """Create a ``Registry()`` handle and keep it on ``self.registry``."""
     registry = Registry()
     self.registry = registry
Beispiel #11
0
 def __init__(self, config):
     """Keep ``config`` on ``self.config`` and a ``Registry()`` handle on ``self.registry``."""
     self.config, self.registry = config, Registry()
Beispiel #12
0
 def __call__(self, obj):
     """
     Add ``obj`` to the registry under ``self.name`` and return it unchanged.

     Returning ``obj`` lets the instance be applied as a decorator
     (presumably its intended use).
     """
     Registry().add(self.name, obj)
     return obj