示例#1
0
    def __init__(self, cls, cache_dir):
        """
        Set up an empty manager for instances of ``cls``.

        Args:
            cls: Class whose instances get trained and cached
            cache_dir (str): Directory used for cache files
        """
        self.cls, self.cache = cls, cache_dir
        # Presumably: ready-to-use objects vs. objects still waiting to be trained
        self.objects, self.objects_to_train = [], []

        self.train_data = TrainData()
示例#2
0
 def setup(self):
     """Build shared training data and train a 'hi' and a 'bye' intent on it."""
     self.data = TrainData()
     samples = (
         ('hi', ['hello', 'hi', 'hi there']),
         ('bye', ['goodbye', 'bye', 'bye {person}', 'see you later']),
     )
     for intent_name, lines in samples:
         self.data.add_lines(intent_name, lines)
     self.i_hi, self.i_bye = Intent('hi'), Intent('bye')
     for intent in (self.i_hi, self.i_bye):
         intent.train(self.data)
示例#3
0
class TestEntityEdge:
    """Tests that ``EntityEdge`` scores the entity slot above its neighbours."""

    def setup_method(self):
        """Build training data containing a ``{word}`` slot and two edge models."""
        self.data = TrainData()
        self.data.add_lines('', ['a {word} here', 'the {word} here'])
        # NOTE(review): -1 / +1 presumably select the left / right side of the
        # entity; confirm against EntityEdge.
        self.le = EntityEdge(-1, '{word}', '')
        self.re = EntityEdge(+1, '{word}', '')

    # pytest 8 removed support for nose-style ``setup``; it only calls
    # ``setup_method``. Keep the old name for runners/callers that still use it.
    setup = setup_method

    def test_match(self):
        """Both edge models must score index 1 (the slot) above indices 0 and 2."""
        self.le.train(self.data)
        self.re.train(self.data)
        sent = ['a', '{word}', 'here']
        assert self.le.match(sent, 1) > self.le.match(sent, 0)
        assert self.le.match(sent, 1) > self.le.match(sent, 2)
        assert self.re.match(sent, 1) > self.re.match(sent, 0)
        assert self.re.match(sent, 1) > self.re.match(sent, 2)
示例#4
0
class TestIntent:
    """Tests for training, matching and saving/loading ``Intent`` objects."""

    def setup_method(self):
        """Train a 'hi' and a 'bye' intent on shared data before each test."""
        self.data = TrainData()
        self.data.add_lines('hi', ['hello', 'hi', 'hi there'])
        self.data.add_lines('bye', ['goodbye', 'bye', 'bye {person}', 'see you later'])
        self.i_hi = Intent('hi')
        self.i_bye = Intent('bye')
        self.i_hi.train(self.data)
        self.i_bye.train(self.data)

    # pytest 8 removed support for nose-style ``setup``; it only calls
    # ``setup_method``. Keep the old name for runners/callers that still use it.
    setup = setup_method

    def test_match(self):
        """Each intent prefers its own phrases, and 'bye {person}' captures the entity."""
        assert self.i_hi.match(['hi']).conf > self.i_hi.match(['bye']).conf
        assert self.i_hi.match(['hi']).conf > self.i_bye.match(['hi']).conf
        assert self.i_bye.match(['bye']).conf > self.i_bye.match(['hi']).conf
        assert self.i_bye.match(['bye']).conf > self.i_hi.match(['bye']).conf

        # Named so that the builtin ``all`` is not shadowed
        bye_conf = self.i_bye.match(['see', 'you', 'later']).conf
        assert bye_conf > self.i_hi.match(['see']).conf
        assert bye_conf > self.i_hi.match(['you']).conf
        assert bye_conf > self.i_hi.match(['later']).conf

        matches = self.i_bye.match(['bye', 'john']).matches
        assert len(matches) == 1
        assert '{person}' in matches
        assert matches['{person}'] == ['john']

    def test_save_load(self):
        """Intents saved to and reloaded from ``temp`` must still pass ``test_match``."""
        if not isdir('temp'):
            mkdir('temp')
        self.i_hi.save('temp')
        self.i_bye.save('temp')

        self.i_hi = Intent.from_file('hi', 'temp')
        self.i_bye = Intent.from_file('bye', 'temp')

        self.test_match()

    def teardown_method(self):
        """Remove the ``temp`` directory that ``test_save_load`` may have created."""
        if isdir('temp'):
            rmtree('temp')

    # Same reason as the ``setup`` alias: pytest 8 only calls ``teardown_method``.
    teardown = teardown_method
示例#5
0
class TestTrainData:
    """Tests for adding lines to ``TrainData`` and querying sentences per name."""

    def setup_method(self):
        """Create fresh training data and a one-line ``temp`` file to load from."""
        self.data = TrainData()
        with open('temp', 'w') as f:
            f.writelines(['hi'])

    # pytest 8 removed support for nose-style ``setup``; it only calls
    # ``setup_method``. Keep the old name for runners/callers that still use it.
    setup = setup_method

    def test_add_lines(self):
        """Lines from a file and from lists are split correctly into my/other/all sentences."""
        self.data.add_file('hi', 'temp')
        self.data.add_lines('bye', ['bye'])
        self.data.add_lines('other', ['other'])

        def same_sents(a, b):
            """Compare two collections of token lists as unordered sets of sentences."""
            return {' '.join(s) for s in a} == {' '.join(s) for s in b}

        assert same_sents(self.data.my_sents('hi'), [['hi']])
        assert same_sents(self.data.other_sents('hi'), [['bye'], ['other']])
        assert same_sents(self.data.all_sents(), [['hi'], ['bye'], ['other']])

    def teardown_method(self):
        """Delete the ``temp`` file written by ``setup_method``."""
        if isfile('temp'):
            os.remove('temp')

    # Same reason as the ``setup`` alias: pytest 8 only calls ``teardown_method``.
    teardown = teardown_method
示例#6
0
 def setup(self):
     """Start from empty training data and write a one-line ``temp`` file."""
     self.data = TrainData()
     with open('temp', 'w') as temp_file:
         temp_file.write('hi')
示例#7
0
class TrainingManager(object):
    """
    Manages multithreaded training of either Intents or Entities

    Args:
        cls (Type[Trainable]): Class to wrap
        cache_dir (str): Place to store cache files
    """
    def __init__(self, cls, cache_dir):
        self.cls = cls
        self.cache = cache_dir
        # Objects loaded from the cache (or after training) and ready for use
        self.objects = []
        # Objects scheduled by add() whose cache is missing/stale; drained by train()
        self.objects_to_train = []

        # Training lines for every registered name, handed to each object's training
        self.train_data = TrainData()

    def add(self, name, lines, reload_cache=False):
        """
        Register training lines under ``name`` and schedule or load its object.

        The object is scheduled for training when ``reload_cache`` is set or when
        the stored ``<name>.hash`` file in the cache differs from the hash of
        ``lines`` prefixed with the padatious version (minus its last dotted
        component). Otherwise the cached object is loaded immediately.

        Args:
            name (str): Name of the intent/entity
            lines (list[str]): Training lines
            reload_cache (bool): Force retraining even if the cache matches
        """
        hash_fn = join(self.cache, name + '.hash')
        old_hsh = None
        if isfile(hash_fn):
            with open(hash_fn, 'rb') as g:
                old_hsh = g.read()
        # splitext('0.4.8') -> ('0.4', '.8'): a patch-version bump keeps the cache valid
        min_ver = splitext(padatious.__version__)[0]
        # NOTE(review): comparison against the raw file bytes assumes lines_hash returns bytes
        new_hsh = lines_hash([min_ver] + lines)
        if reload_cache or old_hsh != new_hsh:
            self.objects_to_train.append(self.cls(name=name, hsh=new_hsh))
        else:
            self.objects.append(self.cls.from_file(name=name, folder=self.cache))
        self.train_data.add_lines(name, lines)

    def load(self, name, file_name, reload_cache=False):
        """
        Register the newline-separated lines of ``file_name`` under ``name``.

        A trailing newline in the file yields a final empty line.
        """
        with open(file_name) as f:
            self.add(name, f.read().split('\n'), reload_cache)

    def remove(self, name):
        """Drop ``name`` from loaded and pending objects and remove its training lines."""
        self.objects = [i for i in self.objects if i.name != name]
        self.objects_to_train = [i for i in self.objects_to_train if i.name != name]
        self.train_data.remove_lines(name)

    def train(self, debug=True, single_thread=False, timeout=20):
        """
        Train every pending object, save it to the cache and load it back.

        Args:
            debug (bool): Print progress and timeout messages
            single_thread (bool): Train in this process instead of a process pool
            timeout (float): Seconds to wait for the pool (multi-process mode only);
                objects that have not been saved by then are skipped
        """
        if not isdir(self.cache):
            mkdir(self.cache)

        # Nothing pending: avoid spawning a whole process pool to map over nothing
        if not self.objects_to_train:
            return

        train = partial(
            _train_and_save, cache=self.cache, data=self.train_data, print_updates=debug
        )

        if single_thread:
            for i in self.objects_to_train:
                train(i)
        else:
            # Train in multiple processes to disk
            pool = mp.Pool()
            timed_out = False
            try:
                pool.map_async(train, self.objects_to_train).get(timeout)
            except mp.TimeoutError:
                # AsyncResult.get raises multiprocessing's TimeoutError, which on
                # Python 3 is not the builtin TimeoutError
                timed_out = True
                if debug:
                    print('Some objects timed out while training')
            finally:
                pool.close()
            if not timed_out:
                # All tasks are done, so this just reaps the idle workers; after a
                # timeout the stragglers are left to finish in the background.
                pool.join()

        # Load saved objects from disk
        for obj in self.objects_to_train:
            try:
                self.objects.append(self.cls.from_file(name=obj.name, folder=self.cache))
            except IOError:
                if debug:
                    print('Took too long to train', obj.name)
        self.objects_to_train = []
示例#8
0
class TrainingManager(object):
    """
    Manages multithreaded training of either Intents or Entities

    Args:
        cls (Type[Trainable]): Class to wrap
        cache_dir (str): Place to store cache files
    """
    def __init__(self, cls, cache_dir):
        self.cls = cls
        self.cache = cache_dir
        # Objects loaded from the cache (or after training) and ready for use
        self.objects = []
        # Objects scheduled by add() whose cache is missing/stale; drained by train()
        self.objects_to_train = []

        # Training lines for every registered name, handed to each object's training
        self.train_data = TrainData()

    def add(self, name, lines, reload_cache=False):
        """
        Register training lines under ``name`` and schedule or load its object.

        The object is scheduled for training when ``reload_cache`` is set or when
        the stored ``<name>.hash`` file in the cache differs from the hash of
        ``lines`` prefixed with the padatious version (minus its last dotted
        component). Otherwise the cached object is loaded immediately.

        Args:
            name (str): Name of the intent/entity
            lines (list[str]): Training lines
            reload_cache (bool): Force retraining even if the cache matches
        """
        hash_fn = join(self.cache, name + '.hash')
        old_hsh = None
        if isfile(hash_fn):
            with open(hash_fn, 'rb') as g:
                old_hsh = g.read()
        # splitext('0.4.8') -> ('0.4', '.8'): a patch-version bump keeps the cache valid
        min_ver = splitext(padatious.__version__)[0]
        # NOTE(review): comparison against the raw file bytes assumes lines_hash returns bytes
        new_hsh = lines_hash([min_ver] + lines)
        if reload_cache or old_hsh != new_hsh:
            self.objects_to_train.append(self.cls(name=name, hsh=new_hsh))
        else:
            self.objects.append(
                self.cls.from_file(name=name, folder=self.cache))
        self.train_data.add_lines(name, lines)

    def load(self, name, file_name, reload_cache=False):
        """
        Register the newline-separated lines of ``file_name`` under ``name``.

        A trailing newline in the file yields a final empty line.
        """
        with open(file_name) as f:
            self.add(name, f.read().split('\n'), reload_cache)

    def remove(self, name):
        """Drop ``name`` from loaded and pending objects and remove its training lines."""
        self.objects = [i for i in self.objects if i.name != name]
        self.objects_to_train = [
            i for i in self.objects_to_train if i.name != name
        ]
        self.train_data.remove_lines(name)

    def train(self, debug=True, single_thread=False):
        """
        Train every pending object, save it to the cache and load it back.

        Args:
            debug (bool): Passed through to ``_train_and_save`` for each object
            single_thread (bool): Train in this process instead of a process pool

        Raises:
            Any exception raised while training an object is re-raised here.
        """
        if not isdir(self.cache):
            mkdir(self.cache)

        # Nothing pending: avoid spawning a whole process pool for no work
        if not self.objects_to_train:
            return

        def args(i):
            # Positional arguments for _train_and_save
            return i, self.cache, self.train_data, debug

        if single_thread:
            for i in self.objects_to_train:
                _train_and_save(*args(i))
        else:
            # Train in multiple processes to disk
            pool = mp.Pool()
            try:
                results = [
                    pool.apply_async(_train_and_save, args(i))
                    for i in self.objects_to_train
                ]

                # get() blocks until done and re-raises any worker exception
                for i in results:
                    i.get()
            finally:
                pool.close()
            # Only reached when every task succeeded: reap the now-idle workers
            pool.join()

        # Load saved objects from disk
        for obj in self.objects_to_train:
            self.objects.append(
                self.cls.from_file(name=obj.name, folder=self.cache))
        self.objects_to_train = []
示例#9
0
 def setup(self):
     """Prepare ``{word}`` training data and two ``EntityEdge`` models (direction -1 and +1)."""
     self.data = TrainData()
     self.data.add_lines('', ['a {word} here', 'the {word} here'])
     self.le, self.re = EntityEdge(-1, '{word}', ''), EntityEdge(+1, '{word}', '')