示例#1
0
class IntentContainer(object):
    """
    Creates an IntentContainer object used to load and match intents

    Each intent/entity is registered both with the neural managers
    (IntentManager / EntityManager) and with a padaos regex container;
    padaos matches are reported by calc_intents with confidence 1.0.

    Args:
        cache_dir (str): Place to put all saved neural networks
    """
    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        # clear() creates the cache directory and initialises all state
        self.clear()

    def clear(self):
        """
        Reset the container to the state of a freshly constructed one.

        Creates the cache directory if it does not exist, replaces the
        intent/entity managers and the padaos container with fresh instances
        and forgets all recorded registration calls.
        """
        os.makedirs(self.cache_dir, exist_ok=True)
        self.must_train = False
        self.intents = IntentManager(self.cache_dir)
        self.entities = EntityManager(self.cache_dir)
        self.padaos = padaos.IntentContainer()
        self.train_thread = None  # type: Thread
        # Arguments of all calls to register intents/entities; replayed by
        # apply_training_args() (e.g. at the end of train_subprocess)
        self.serialized_args = []

    def instantiate_from_disk(self):
        """
        Instantiates the necessary (internal) data structures when loading persisted model from disk.
        This is done via injecting entities and intents back from cached file versions.

        For every ``{name}.hash`` file in the cache directory the entity
        ``name`` is re-added using the lines of ``name.entity``; for every
        other ``name.hash`` file the intent ``name`` is re-added using the
        lines of ``name.intent``. The re-added items do not request training.

        Raises:
            KeyError: if a ``.hash`` file has no matching training data file
        """
        file_names = os.listdir(self.cache_dir)
        entity_traindata = {}
        intent_traindata = {}

        # workaround: load training data for both entities and intents since
        # padaos regex needs it for (re)compilation until TODO is cleared
        for f in file_names:
            # Strip the suffix by length: find() would cut at an earlier
            # occurrence of the suffix text inside the name itself
            if f.endswith('.entity'):
                traindata, name = entity_traindata, f[:-len('.entity')]
            elif f.endswith('.intent'):
                traindata, name = intent_traindata, f[:-len('.intent')]
            else:
                continue
            with open(os.path.join(self.cache_dir, f), 'r') as d:
                traindata[name] = [line.strip() for line in d]

        # TODO: padaos.compile (regex compilation) is redone when loading: find
        # a way to persist regex, as well!
        for f in file_names:
            if f.startswith('{') and f.endswith('}.hash'):
                entity_name = f[1:-len('}.hash')]
                self.add_entity(name=entity_name,
                                lines=entity_traindata[entity_name],
                                reload_cache=False,
                                must_train=False)
            elif not f.startswith('{') and f.endswith('.hash'):
                intent_name = f[:-len('.hash')]
                self.add_intent(name=intent_name,
                                lines=intent_traindata[intent_name],
                                reload_cache=False,
                                must_train=False)

    @_save_args
    def add_intent(self, name, lines, reload_cache=False, must_train=True):
        """
        Creates a new intent, optionally checking the cache first

        Args:
            name (str): The associated name of the intent
            lines (list<str>): All the sentences that should activate the intent
            reload_cache: Whether to ignore cached intent if exists
            must_train (bool): Whether this registration requires training;
                a pending training request from an earlier call is kept
        """
        self.intents.add(name, lines, reload_cache, must_train)
        self.padaos.add_intent(name, lines)
        # Never cancel training still pending for previously added items
        self.must_train = self.must_train or must_train

    @_save_args
    def add_entity(self, name, lines, reload_cache=False, must_train=True):
        """
        Adds an entity that matches the given lines.

        Example:
            self.add_intent('weather', ['will it rain on {weekday}?'])
            self.add_entity('weekday', ['monday', 'tuesday', 'wednesday'])  # ...

        Args:
            name (str): The name of the entity (without wrapping braces)
            lines (list<str>): Lines of example extracted entities
            reload_cache (bool): Whether to refresh all of cache
            must_train (bool): Whether this registration requires training;
                a pending training request from an earlier call is kept
        """
        Entity.verify_name(name)
        self.entities.add(Entity.wrap_name(name), lines, reload_cache,
                          must_train)
        self.padaos.add_entity(name, lines)
        # Never cancel training still pending for previously added items
        self.must_train = self.must_train or must_train

    @_save_args
    def load_entity(self,
                    name,
                    file_name,
                    reload_cache=False,
                    must_train=True):
        """
        Loads an entity, optionally checking the cache first

        Args:
            name (str): The associated name of the entity
            file_name (str): The location of the entity file
            reload_cache (bool): Whether to refresh all of cache
            must_train (bool): Whether this registration requires training;
                a pending training request from an earlier call is kept
        """
        Entity.verify_name(name)
        self.entities.load(Entity.wrap_name(name), file_name, reload_cache)
        with open(file_name) as f:
            self.padaos.add_entity(name, f.read().split('\n'))
        self.must_train = self.must_train or must_train

    def load_file(self, *args, **kwargs):
        """
        Legacy. Use load_intent instead

        Deliberately not decorated with ``_save_args``: the delegated (and
        decorated) load_intent call is what gets recorded, so replaying the
        recorded arguments does not load the same intent twice.
        """
        self.load_intent(*args, **kwargs)

    @_save_args
    def load_intent(self,
                    name,
                    file_name,
                    reload_cache=False,
                    must_train=True):
        """
        Loads an intent, optionally checking the cache first

        Args:
            name (str): The associated name of the intent
            file_name (str): The location of the intent file
            reload_cache (bool): Whether to refresh all of cache
            must_train (bool): Whether this registration requires training;
                a pending training request from an earlier call is kept
        """
        self.intents.load(name, file_name, reload_cache)
        with open(file_name) as f:
            self.padaos.add_intent(name, f.read().split('\n'))
        self.must_train = self.must_train or must_train

    @_save_args
    def remove_intent(self, name):
        """Unload an intent"""
        self.intents.remove(name)
        self.padaos.remove_intent(name)
        self.must_train = True

    @_save_args
    def remove_entity(self, name):
        """Unload an entity"""
        self.entities.remove(name)
        self.padaos.remove_entity(name)

    def _train(self, *args, **kwargs):
        """
        Train intents and entities on two parallel threads, wait for both,
        then rebuild the entity dictionary. Arguments are forwarded to both
        managers' train() methods.
        """
        t1 = Thread(target=self.intents.train,
                    args=args,
                    kwargs=kwargs,
                    daemon=True)
        t2 = Thread(target=self.entities.train,
                    args=args,
                    kwargs=kwargs,
                    daemon=True)
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        self.entities.calc_ent_dict()

    def train(self, debug=True, force=False, single_thread=False, timeout=20):
        """
        Trains all the loaded intents that need to be updated
        If a cache file exists with the same hash as the intent file,
        the intent will not be trained and just loaded from file

        Args:
            debug (bool): Whether to print a message to stdout each time a new intent is trained
            force (bool): Whether to force training if already finished
            single_thread (bool): Whether to force running in a single thread
            timeout (float): Seconds before cancelling training
        Returns:
            bool: True if training succeeded without timeout (also True
                when there was nothing to train)
        """
        if not self.must_train and not force:
            # Nothing to do counts as success, as documented above
            return True
        self.padaos.compile()
        self.train_thread = Thread(target=self._train,
                                   kwargs=dict(debug=debug,
                                               single_thread=single_thread,
                                               timeout=timeout),
                                   daemon=True)
        self.train_thread.start()
        self.train_thread.join(timeout)

        # Cleared even on timeout; calc_intents skips the neural results
        # while the (daemon) training thread is still alive
        self.must_train = False
        return not self.train_thread.is_alive()

    def train_subprocess(self, *args, **kwargs):
        """
        Trains in a subprocess which provides a timeout guarantees everything shuts down properly

        The recorded registration calls are sent to ``python -m padatious
        train``; afterwards this container is cleared and rebuilt by
        replaying the same calls.

        Args:
            See <train>
        Returns:
            bool: True for success, False if timed out
        Raises:
            TypeError: if the subprocess rejects the train arguments (exit 2)
            ValueError: for any other non-zero exit code except 10 (timeout)
        """
        ret = call([
            sys.executable,
            '-m',
            'padatious',
            'train',
            self.cache_dir,
            '-d',
            json.dumps(self.serialized_args),
            '-a',
            json.dumps(args),
            '-k',
            json.dumps(kwargs),
        ])
        if ret == 2:
            raise TypeError('Invalid train arguments: {} {}'.format(
                args, kwargs))
        data = self.serialized_args
        self.clear()
        self.apply_training_args(data)
        self.padaos.compile()
        if ret == 0:
            self.must_train = False
            return True
        elif ret == 10:  # timeout
            return False
        else:
            raise ValueError(
                'Training failed and returned code: {}'.format(ret))

    def calc_intents(self, query):
        """
        Tests all the intents against the query and returns
        data on how well each one matched against the query

        Args:
            query (str): Input sentence to test against intents
        Returns:
            list<MatchData>: List of intent matches
        See calc_intent() for a description of the returned MatchData
        """
        if self.must_train:
            self.train()
        # While a timed-out training thread is still running only the padaos
        # (exact template) matches are used
        intents = {} if self.train_thread and self.train_thread.is_alive(
        ) else {
            i.name: i
            for i in self.intents.calc_intents(query, self.entities)
        }
        sent = tokenize(query)
        for perfect_match in self.padaos.calc_intents(query):
            name = perfect_match['name']
            intents[name] = MatchData(name,
                                      sent,
                                      matches=perfect_match['entities'],
                                      conf=1.0)
        return list(intents.values())

    def calc_intent(self, query):
        """
        Tests all the intents against the query and returns
        match data of the best intent

        Among matches sharing the highest confidence, the one whose extracted
        entity values have the smallest total length wins.

        Args:
            query (str): Input sentence to test against intents
        Returns:
            MatchData: Best intent match
        """
        matches = self.calc_intents(query)
        if not matches:
            return MatchData('', '')
        best_match = max(matches, key=lambda x: x.conf)
        best_matches = (match for match in matches
                        if match.conf == best_match.conf)
        return min(best_matches,
                   key=lambda x: sum(map(len, x.matches.values())))

    def get_training_args(self):
        """Return the recorded registration calls (the live list, not a copy)."""
        return self.serialized_args

    def apply_training_args(self, data):
        """
        Replay recorded registration calls, e.g. from get_training_args().

        Args:
            data (list<dict>): Each dict holds the method name under
                ``'__name__'`` plus that call's keyword arguments. The input
                is not modified.
        """
        # Snapshot: the replayed methods append to self.serialized_args, which
        # may be the very list passed in (endless loop otherwise)
        for params in list(data):
            kwargs = dict(params)  # copy so the caller's dict keeps '__name__'
            func_name = kwargs.pop('__name__')
            getattr(self, func_name)(**kwargs)
示例#2
0
class IntentContainer(object):
    """
    Creates an IntentContainer object used to load and match intents

    Each intent/entity is registered both with the neural managers and with
    a padaos regex container; registering anything marks the container as
    needing training, which calc_intents performs on demand.

    Args:
        cache_dir (str): Place to put all saved neural networks
    """
    def __init__(self, cache_dir):
        self.must_train = False
        self.intents = IntentManager(cache_dir)
        self.entities = EntityManager(cache_dir)
        self.padaos = padaos.IntentContainer()

    def add_intent(self, name, lines, reload_cache=False):
        """
        Creates a new intent, optionally checking the cache first

        Args:
            name (str): The associated name of the intent
            lines (list<str>): All the sentences that should activate the intent
            reload_cache: Whether to ignore cached intent if exists
        """
        self.intents.add(name, lines, reload_cache)
        self.padaos.add_intent(name, lines)
        self.must_train = True

    def add_entity(self, name, lines, reload_cache=False):
        """
        Adds an entity that matches the given lines.

        Example:
            self.add_intent('weather', ['will it rain on {weekday}?'])
            self.add_entity('weekday', ['monday', 'tuesday', 'wednesday'])  # ...

        Args:
            name (str): The name of the entity, without wrapping braces
                (they are added internally via Entity.wrap_name)
            lines (list<str>): Lines of example extracted entities
            reload_cache (bool): Whether to refresh all of cache
        """
        Entity.verify_name(name)
        self.entities.add(Entity.wrap_name(name), lines, reload_cache)
        self.padaos.add_entity(name, lines)
        self.must_train = True

    def load_entity(self, name, file_name, reload_cache=False):
        """
        Loads an entity, optionally checking the cache first

        Args:
            name (str): The associated name of the entity (without braces)
            file_name (str): The location of the entity file
            reload_cache (bool): Whether to refresh all of cache
        """
        Entity.verify_name(name)
        self.entities.load(Entity.wrap_name(name), file_name, reload_cache)
        with open(file_name) as f:
            self.padaos.add_entity(name, f.read().split('\n'))
        # Same as add_entity: a newly loaded entity must be trained before
        # calc_intents can use it
        self.must_train = True

    def load_file(self, *args, **kwargs):
        """Legacy. Use load_intent instead"""
        self.load_intent(*args, **kwargs)

    def load_intent(self, name, file_name, reload_cache=False):
        """
        Loads an intent, optionally checking the cache first

        Args:
            name (str): The associated name of the intent
            file_name (str): The location of the intent file
            reload_cache (bool): Whether to refresh all of cache
        """
        self.intents.load(name, file_name, reload_cache)
        with open(file_name) as f:
            self.padaos.add_intent(name, f.read().split('\n'))
        # Same as add_intent: a newly loaded intent must be trained before
        # calc_intents can use it
        self.must_train = True

    def remove_intent(self, name):
        """Unload an intent"""
        self.intents.remove(name)
        self.padaos.remove_intent(name)
        self.must_train = True

    def remove_entity(self, name):
        """Unload an entity"""
        self.entities.remove(name)
        self.padaos.remove_entity(name)

    def train(self, *args, **kwargs):
        """
        Trains all the loaded intents that need to be updated
        If a cache file exists with the same hash as the intent file,
        the intent will not be trained and just loaded from file

        Arguments are forwarded unchanged to both the intent and the entity
        manager's train() methods.

        Args:
            print_updates (bool): Whether to print a message to stdout
                each time a new intent is trained
            single_thread (bool): Whether to force running in a single thread
        """
        self.intents.train(*args, **kwargs)
        self.entities.train(*args, **kwargs)
        self.entities.calc_ent_dict()
        self.padaos.compile()
        self.must_train = False

    def calc_intents(self, query):
        """
        Tests all the intents against the query and returns
        data on how well each one matched against the query

        Intents returned by the neural manager that padaos also matches
        exactly get confidence 1.0 and the padaos-extracted entities.

        Args:
            query (str): Input sentence to test against intents
        Returns:
            list<MatchData>: List of intent matches
        See calc_intent() for a description of the returned MatchData
        """
        if self.must_train:
            self.train()
        intents = {
            i.name: i
            for i in self.intents.calc_intents(query, self.entities)
        }
        for perfect_match in self.padaos.calc_intents(query):
            intent = intents.get(perfect_match['name'])
            if intent:
                intent.conf = 1.0
                intent.matches = perfect_match['entities']
        return list(intents.values())

    def calc_intent(self, query):
        """
        Tests all the intents against the query and returns
        match data of the best intent

        Among matches sharing the highest confidence, the one whose extracted
        entity values have the smallest total length wins.

        Args:
            query (str): Input sentence to test against intents
        Returns:
            MatchData: Best intent match
        """
        matches = self.calc_intents(query)
        if not matches:
            return MatchData('', '')
        best_match = max(matches, key=lambda x: x.conf)
        best_matches = (match for match in matches
                        if match.conf == best_match.conf)
        return min(best_matches,
                   key=lambda x: sum(map(len, x.matches.values())))
示例#3
0
class IntentContainer(object):
    """
    Load intents and entities and match sentences against them.

    This container is a thin facade: intent operations are forwarded to an
    IntentManager and entity operations to an EntityManager, both created
    for the same cache directory.

    Args:
        cache_dir (str): Place to put all saved neural networks
    """
    def __init__(self, cache_dir):
        intent_manager = IntentManager(cache_dir)
        entity_manager = EntityManager(cache_dir)
        self.intents = intent_manager
        self.entities = entity_manager

    def add_intent(self, *args, **kwargs):
        """
        Register a new intent, optionally consulting the cache first.

        Args:
            name (str): Name the intent is known by
            lines (list<str>): Example sentences that should trigger the intent
            reload_cache: Whether to ignore a cached intent if one exists
        """
        manager = self.intents
        manager.add(*args, **kwargs)

    def add_entity(self, name, *args, **kwargs):
        """
        Register an entity matching the given example lines.

        The name is given without braces; it is validated with
        Entity.verify_name and wrapped with Entity.wrap_name before being
        handed to the entity manager.

        Example:
            self.add_intent('weather', ['will it rain on {weekday}?'])
            self.add_entity('weekday', ['monday', 'tuesday', 'wednesday'])  # ...

        Args:
            name (str): Name of the entity
            lines (list<str>): Example values of the entity
            reload_cache (bool): Whether to refresh all of cache
        """
        Entity.verify_name(name)
        wrapped_name = Entity.wrap_name(name)
        self.entities.add(wrapped_name, *args, **kwargs)

    def load_entity(self, name, *args, **kwargs):
        """
        Load an entity from a file, optionally consulting the cache first.

        Args:
            name (str): Name of the entity (without braces)
            file_name (str): Path of the entity file
            reload_cache (bool): Whether to refresh all of cache
        """
        Entity.verify_name(name)
        wrapped_name = Entity.wrap_name(name)
        self.entities.load(wrapped_name, *args, **kwargs)

    def load_file(self, *args, **kwargs):
        """Legacy alias kept for compatibility; prefer load_intent."""
        self.load_intent(*args, **kwargs)

    def load_intent(self, *args, **kwargs):
        """
        Load an intent from a file, optionally consulting the cache first.

        Args:
            name (str): Name the intent is known by
            file_name (str): Path of the intent file
            reload_cache (bool): Whether to refresh all of cache
        """
        manager = self.intents
        manager.load(*args, **kwargs)

    def train(self, *args, **kwargs):
        """
        Train every loaded intent and entity that needs updating.

        Items whose cache file carries the same hash as their source are
        loaded from the cache instead of being retrained. All arguments are
        forwarded to the intent manager first, then to the entity manager,
        after which the entity dictionary is recomputed.

        Args:
            print_updates (bool): Whether to print a message to stdout
                each time a new intent is trained
            single_thread (bool): Whether to force running in a single thread
        """
        for manager in (self.intents, self.entities):
            manager.train(*args, **kwargs)
        self.entities.calc_ent_dict()

    def calc_intents(self, query):
        """
        Score every intent against a sentence.

        Args:
            query (str): Sentence to test against the intents
        Returns:
            list<MatchData>: One match entry per intent
        See calc_intent() for a description of the returned MatchData
        """
        entity_manager = self.entities
        return self.intents.calc_intents(query, entity_manager)

    def calc_intent(self, query):
        """
        Find the intent that best matches a sentence.

        Args:
            query (str): Sentence to test against the intents
        Returns:
            MatchData: The best intent match
        """
        entity_manager = self.entities
        return self.intents.calc_intent(query, entity_manager)