Exemplo n.º 1
0
Arquivo: lm.py Projeto: wollmers/gecco
    def train(self, sourcefile, modelfile, **parameters):
        """Build, train and save a Timbl classifier from a plain-text corpus.

        Every line of ``sourcefile`` is cut into windows of
        ``leftcontext + 1 + rightcontext`` tokens by ``Windower``. For each
        window the middle token (index ``leftcontext``) is the class label and
        the remaining tokens, left context followed by right context, are the
        features.

        Args:
            sourcefile: Path of the training corpus. A ``.bz2`` or ``.gz``
                suffix selects the matching decompressor; anything else is
                opened as plain text. The file is read as UTF-8.
            modelfile: Target model path; ``.ibase`` is removed from it to
                obtain the prefix handed to ``TimblClassifier``.
            **parameters: Accepted for interface compatibility; not used here.
        """
        left = self.settings['leftcontext']
        right = self.settings['rightcontext']
        windowsize = left + 1 + right

        self.log("Generating training instances...")
        # The ".ibase" extension has been verified earlier.
        prefix = modelfile.replace(".ibase", "")
        classifier = TimblClassifier(prefix, self.gettimbloptions())

        # Pick the opener that matches the compression of the corpus.
        if sourcefile.endswith(".bz2"):
            opener = bz2.open
        elif sourcefile.endswith(".gz"):
            opener = gzip.open
        else:
            opener = io.open

        with opener(sourcefile, mode='rt', encoding='utf-8') as corpus:
            for linenum, line in enumerate(corpus):
                if linenum % 100000 == 0:
                    # Progress report on stderr every 100000 lines.
                    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    print(stamp + " - " + str(linenum), file=sys.stderr)
                for ngram in Windower(line, windowsize):
                    focus = ngram[left]
                    features = tuple(ngram[:left]) + tuple(ngram[left + 1:])
                    classifier.append(features, focus)

        self.log("Training classifier...")
        classifier.train()

        self.log("Saving model " + modelfile)
        classifier.save()
Exemplo n.º 2
0
    def train(self, sourcefile, modelfile, **parameters):
        """Build, train and save the Timbl classifier for punctuation/recasing.

        The corpus is read line by line and split on single spaces. Every
        token that contains at least one alphabetic character is pushed onto a
        buffer as ``(word, is_capitalised, preceding_punctuation)``, where
        ``is_capitalised`` is true when the word equals its own
        "first letter upper, rest lower" form. ``preceding_punctuation`` is the
        previous token when that token is listed in
        ``TIMBLPuncRecaseModule.PUNCTUATION``, otherwise the empty string.
        Whenever the buffer holds ``leftcontext + rightcontext + 1`` entries it
        is handed to ``self.addtraininstance``, whose return value becomes the
        new buffer.

        Args:
            sourcefile: Path of the training corpus. A ``.bz2`` or ``.gz``
                suffix selects the matching decompressor. The file is read as
                UTF-8 and undecodable bytes are ignored.
            modelfile: Target model path; ``.ibase`` is removed from it to
                obtain the prefix handed to ``TimblClassifier``.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if self.hapaxer:
            self.log("Training hapaxer...")
            self.hapaxer.train()

        l = self.settings['leftcontext']
        r = self.settings['rightcontext']
        windowsize = l + r + 1

        self.log("Generating training instances...")
        fileprefix = modelfile.replace(".ibase", "")  #has been verified earlier
        classifier = TimblClassifier(fileprefix, self.gettimbloptions())
        if sourcefile.endswith(".bz2"):
            iomodule = bz2
        elif sourcefile.endswith(".gz"):
            iomodule = gzip
        else:
            iomodule = io

        # Both prevword and buffer live outside the line loop, so context
        # carries over from one line to the next. Padding the stream with
        # <begin>/<end> markers is currently disabled, so no padding entries
        # are ever added to the buffer.
        prevword = ""
        buffer = []
        with iomodule.open(sourcefile,
                           mode='rt',
                           encoding='utf-8',
                           errors='ignore') as f:
            for i, line in enumerate(f):
                if i % 100000 == 0:
                    # Progress report on stderr every 100000 lines.
                    print(
                        datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +
                        " - " + str(i),
                        file=sys.stderr)
                # Split on single spaces only (not on arbitrary whitespace),
                # then trim each piece and drop the empty ones.
                for rawword in line.split(' '):
                    word = rawword.strip()
                    if not word:
                        continue
                    if prevword in TIMBLPuncRecaseModule.PUNCTUATION:
                        punc = prevword
                    else:
                        punc = ""
                    if any(c.isalpha() for c in word):
                        capitalised = word == word[0].upper() + word[1:].lower()
                        buffer.append((word, capitalised, punc))
                    if len(buffer) == windowsize:
                        buffer = self.addtraininstance(classifier, buffer, l,
                                                       r)
                    prevword = word

        self.log("Training classifier...")
        classifier.train()

        self.log("Saving model " + modelfile)
        classifier.save()
Exemplo n.º 3
0
def create_classifier_and_word_freq_list(train_instances,timbl_models_folder,train_users,test_user,tweet_index):
    """Train a Timbl classifier on all instances except one held-out tweet.

    The model name is ``<test_user>.<train_users joined by '_'>.<tweet_index>``
    and is appended directly to ``timbl_models_folder`` (no path separator is
    inserted). Instances written by ``test_user`` whose
    ``original_tweet_index`` equals ``tweet_index`` are left out of training.

    Returns:
        A ``(classifier, word_frequencies)`` tuple: the trained
        ``TimblClassifier`` and a ``Counter`` of how often each label occurred
        among the instances that were used.
    """
    model_name = '.'.join((test_user, '_'.join(train_users), str(tweet_index)))
    classifier = TimblClassifier(timbl_models_folder + model_name, '-a 0 -k 1 +vs')
    word_frequencies = Counter()

    for instance in train_instances:
        held_out = (instance.author == test_user
                    and instance.original_tweet_index == tweet_index)
        if not held_out:
            features, label = instance.features, instance.label
            classifier.append(features, label)
            word_frequencies[label] += 1

    classifier.train()

    return classifier, word_frequencies
Exemplo n.º 4
0
    def train(self, sourcefile, modelfile, **parameters):
        """Build, train and save the Timbl classifier for punctuation/recasing.

        The corpus is read line by line and split on single spaces. Every
        token containing at least one alphabetic character is buffered as
        ``(word, is_capitalised, preceding_punctuation)``: ``is_capitalised``
        is true when the word equals its "first letter upper, rest lower"
        form, and ``preceding_punctuation`` is the previous token if it occurs
        in ``PUNCTUATION`` (empty string otherwise). Once the buffer holds
        ``leftcontext + rightcontext + 1`` entries it is passed to
        ``self.addtraininstance``, whose return value replaces the buffer.

        Args:
            sourcefile: Path of the training corpus. A ``.bz2`` or ``.gz``
                suffix selects the matching decompressor. The file is read as
                UTF-8 and undecodable bytes are ignored.
            modelfile: Target model path; ``.ibase`` is removed from it to
                obtain the prefix handed to ``TimblClassifier``.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if self.hapaxer:
            self.log("Training hapaxer...")
            self.hapaxer.train()

        l = self.settings['leftcontext']
        r = self.settings['rightcontext']
        windowsize = l + r + 1

        self.log("Generating training instances...")
        fileprefix = modelfile.replace(".ibase","") #has been verified earlier
        classifier = TimblClassifier(fileprefix, self.gettimbloptions())
        if sourcefile.endswith(".bz2"):
            iomodule = bz2
        elif sourcefile.endswith(".gz"):
            iomodule = gzip
        else:
            iomodule = io

        # prevword and buffer are initialised once, so context carries over
        # line boundaries. Padding with <begin>/<end> markers is currently
        # disabled, so no padding entries are ever added to the buffer.
        prevword = ""
        buffer = []
        with iomodule.open(sourcefile,mode='rt',encoding='utf-8',errors='ignore') as f:
            for i, line in enumerate(f):
                if i % 100000 == 0:
                    # Progress report on stderr every 100000 lines.
                    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + " - " + str(i),file=sys.stderr)
                # Split on single spaces only, trim each piece, skip empties.
                for rawword in line.split(' '):
                    word = rawword.strip()
                    if not word:
                        continue
                    if prevword in PUNCTUATION:
                        punc = prevword
                    else:
                        punc = ""
                    if any(c.isalpha() for c in word):
                        capitalised = word == word[0].upper() + word[1:].lower()
                        buffer.append((word, capitalised, punc))
                    if len(buffer) == windowsize:
                        buffer = self.addtraininstance(classifier, buffer,l,r)
                    prevword = word

        self.log("Training classifier...")
        classifier.train()

        self.log("Saving model " + modelfile)
        classifier.save()
Exemplo n.º 5
0
    def train(self, sourcefile, modelfile, **parameters):
        """Build, train and save a Timbl classifier for the configured confusibles.

        Each corpus line is cut into windows of
        ``leftcontext + 1 + rightcontext`` tokens by ``Windower``. A window
        yields a training instance only when its middle token is found in
        ``self.settings['confusibles']``; that token is the class label. The
        features are the left and right context tokens, taken after the window
        has been passed through the hapaxer when one is configured.

        Args:
            sourcefile: Path of the training corpus. A ``.bz2`` or ``.gz``
                suffix selects the matching decompressor. The file is read as
                UTF-8 and undecodable bytes are ignored.
            modelfile: Target model path; ``.ibase`` is removed from it to
                obtain the prefix handed to ``TimblClassifier``.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if self.hapaxer:
            self.log("Training hapaxer...")
            self.hapaxer.train()

        left = self.settings['leftcontext']
        right = self.settings['rightcontext']
        windowsize = left + 1 + right

        self.log("Generating training instances...")
        # The ".ibase" extension has been verified earlier.
        prefix = modelfile.replace(".ibase", "")
        classifier = TimblClassifier(prefix, self.gettimbloptions())

        if sourcefile.endswith(".bz2"):
            opener = bz2.open
        elif sourcefile.endswith(".gz"):
            opener = gzip.open
        else:
            opener = io.open

        with opener(sourcefile, mode='rt', encoding='utf-8', errors='ignore') as corpus:
            for linenum, line in enumerate(corpus):
                if linenum % 100000 == 0:
                    # Progress report on stderr every 100000 lines.
                    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    print(stamp + " - " + str(linenum), file=sys.stderr)
                for ngram in Windower(line, windowsize):
                    # The label is read before the hapaxer touches the window.
                    confusible = ngram[left]
                    if confusible not in self.settings['confusibles']:
                        continue
                    if self.hapaxer:
                        ngram = self.hapaxer(ngram)
                    features = tuple(ngram[:left]) + tuple(ngram[left + 1:])
                    classifier.append(features, confusible)

        self.log("Training classifier...")
        classifier.train()

        self.log("Saving model " + modelfile)
        classifier.save()
Exemplo n.º 6
0
    def train(self, sourcefile, modelfile, **parameters):
        """Build, train and save a Timbl classifier for the configured confusibles.

        ``Windower`` cuts every corpus line into windows of
        ``leftcontext + 1 + rightcontext`` tokens. Only windows whose middle
        token occurs in ``self.settings["confusibles"]`` become training
        instances: the middle token is the class, and the surrounding tokens
        (after optional hapaxing of the window) are the features.

        Args:
            sourcefile: Path of the training corpus. A ``.bz2`` or ``.gz``
                suffix selects the matching decompressor. The file is read as
                UTF-8 and undecodable bytes are ignored.
            modelfile: Target model path; ``.ibase`` is removed from it to
                obtain the prefix handed to ``TimblClassifier``.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if self.hapaxer:
            self.log("Training hapaxer...")
            self.hapaxer.train()

        focus_index = self.settings["leftcontext"]
        right = self.settings["rightcontext"]
        windowsize = focus_index + 1 + right

        self.log("Generating training instances...")
        # The ".ibase" extension has been verified earlier.
        classifier = TimblClassifier(modelfile.replace(".ibase", ""), self.gettimbloptions())

        if sourcefile.endswith(".bz2"):
            open_corpus = bz2.open
        elif sourcefile.endswith(".gz"):
            open_corpus = gzip.open
        else:
            open_corpus = io.open

        with open_corpus(sourcefile, mode="rt", encoding="utf-8", errors="ignore") as corpus:
            for lineno, line in enumerate(corpus):
                if lineno % 100000 == 0:
                    # Progress report on stderr every 100000 lines.
                    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    print(now + " - " + str(lineno), file=sys.stderr)
                for window in Windower(line, windowsize):
                    # The label is read before the hapaxer touches the window.
                    confusible = window[focus_index]
                    if confusible not in self.settings["confusibles"]:
                        continue
                    if self.hapaxer:
                        window = self.hapaxer(window)
                    before = tuple(window[:focus_index])
                    after = tuple(window[focus_index + 1:])
                    classifier.append(before + after, confusible)

        self.log("Training classifier...")
        classifier.train()

        self.log("Saving model " + modelfile)
        classifier.save()
Exemplo n.º 7
0
    def train(self, sourcefile, modelfile, **parameters):
        """Train one of the two models of the suffix confusible module.

        Which model is built depends on ``modelfile``:

        * ``self.confusiblefile``: build a unigram frequency list of the
          corpus with colibricore, find words ending in one of
          ``self.suffixes`` whose variants with *every* other suffix are also
          in the frequency list, store them in ``self.confusibles`` and write
          them to ``modelfile``, one per line.
        * ``self.modelfile``: train a Timbl classifier. For every corpus
          window whose middle token is a known confusible, the features are
          left context + normalised form + right context and the class is the
          suffix, both obtained from ``self.getsuffix``.

        Any other ``modelfile`` value is silently ignored.

        Args:
            sourcefile: Path of the training corpus (optionally ``.bz2`` or
                ``.gz`` compressed for the classifier phase).
            modelfile: Path of the model to produce; selects the phase.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if modelfile == self.confusiblefile:
            #Build frequency list
            self.log(
                "Preparing to generate lexicon for suffix confusible module")
            classfile = stripsourceextensions(sourcefile) + ".cls"
            corpusfile = stripsourceextensions(sourcefile) + ".dat"

            # The class file and encoded corpus are reused when already on disk.
            if not os.path.exists(classfile):
                self.log("Building class file")
                classencoder = colibricore.ClassEncoder(
                    "", self.settings['minlength'],
                    self.settings['maxlength'])  #character length constraints
                classencoder.build(sourcefile)
                classencoder.save(classfile)
            else:
                classencoder = colibricore.ClassEncoder(
                    classfile, self.settings['minlength'],
                    self.settings['maxlength'])

            if not os.path.exists(corpusfile):
                self.log("Encoding corpus")
                classencoder.encodefile(sourcefile, corpusfile)

            self.log("Generating frequency list")
            options = colibricore.PatternModelOptions(
                mintokens=self.settings['freqthreshold'],
                minlength=1,
                maxlength=1)  #unigrams only
            model = colibricore.UnindexedPatternModel()
            model.train(corpusfile, options)

            self.log("Finding confusible pairs")
            classdecoder = colibricore.ClassDecoder(classfile)
            self.confusibles = []  #pylint: disable=attribute-defined-outside-init
            # Mirrors self.confusibles for O(1) "already listed?" checks; the
            # list itself keeps the output order.
            listed = set()
            for pattern in model:
                try:
                    pattern_s = pattern.tostring(classdecoder)
                except UnicodeDecodeError:
                    self.log(
                        "WARNING: Unable to decode a pattern in the model!!! Invalid utf-8!"
                    )
                    # Skip the pattern: without this, pattern_s would be
                    # unbound (first iteration) or left over from the
                    # previous pattern.
                    continue
                for suffix in self.suffixes:
                    if not pattern_s.endswith(suffix) or pattern_s in listed:
                        continue
                    # Collect the variants with every other suffix; a single
                    # missing or rejected variant discards the whole group.
                    found = []
                    for othersuffix in self.suffixes:
                        if othersuffix == suffix:
                            continue
                        otherpattern_s = pattern_s[:-len(suffix)] + othersuffix
                        try:
                            otherpattern = classencoder.buildpattern(
                                otherpattern_s, False, False)
                        except KeyError:
                            # Variant is not encodable with this class file.
                            found = []
                            break
                        if otherpattern not in model:
                            found = []
                            break
                        if self.settings['maxratio'] != 0:
                            freqs = (model.occurrencecount(pattern),
                                     model.occurrencecount(otherpattern))
                            ratio = max(freqs) / min(freqs)
                            # NOTE(review): groups are rejected when the
                            # frequency ratio is *below* 'maxratio'; confirm
                            # this direction matches the setting's intent.
                            if ratio < self.settings['maxratio']:
                                found = []
                                break
                        found.append(otherpattern_s)
                    if found:
                        self.confusibles.append(pattern_s)
                        self.confusibles.extend(found)
                        listed.add(pattern_s)
                        listed.update(found)

            self.log("Writing confusible list")
            with open(modelfile, 'w', encoding='utf-8') as f:
                for confusible in self.confusibles:
                    f.write(confusible + "\n")

        elif modelfile == self.modelfile:
            # Load the confusible list from disk unless an earlier call on
            # this instance already populated it.
            if not hasattr(self, 'confusibles'):
                self.confusibles = []
                self.log("Loading confusiblefile")
                with open(self.confusiblefile, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            self.confusibles.append(line)
            # Set copy for O(1) membership tests in the per-window loop below.
            confusibles = set(self.confusibles)

            if self.hapaxer:
                self.log("Training hapaxer...")
                self.hapaxer.train()

            l = self.settings['leftcontext']
            r = self.settings['rightcontext']
            n = l + 1 + r

            self.log("Generating training instances...")
            fileprefix = modelfile.replace(".ibase",
                                           "")  #has been verified earlier
            classifier = TimblClassifier(fileprefix, self.gettimbloptions())
            if sourcefile.endswith(".bz2"):
                iomodule = bz2
            elif sourcefile.endswith(".gz"):
                iomodule = gzip
            else:
                iomodule = io
            with iomodule.open(sourcefile,
                               mode='rt',
                               encoding='utf-8',
                               errors='ignore') as f:
                for i, line in enumerate(f):
                    if i % 100000 == 0:
                        # Progress report once per 100000 lines (previously
                        # emitted once per n-gram of such a line).
                        print(datetime.datetime.now().strftime(
                            "%Y-%m-%d %H:%M:%S") + " - " + str(i),
                              file=sys.stderr)
                    for ngram in Windower(line, n):
                        # The focus word is read before the hapaxer touches
                        # the window.
                        confusible = ngram[l]
                        if confusible not in confusibles:
                            continue
                        if self.hapaxer:
                            ngram = self.hapaxer(ngram)
                        leftcontext = tuple(ngram[:l])
                        rightcontext = tuple(ngram[l + 1:])
                        suffix, normalized = self.getsuffix(confusible)
                        if suffix is not None:
                            classifier.append(
                                leftcontext + (normalized, ) +
                                rightcontext, suffix)

            self.log("Training classifier...")
            classifier.train()

            self.log("Saving model " + modelfile)
            classifier.save()
Exemplo n.º 8
0
    def train(self, sourcefile, modelfile, **parameters):
        """Train one of the two models of the suffix confusible module.

        Which model is built depends on ``modelfile``:

        * ``self.confusiblefile``: build a unigram frequency list of the
          corpus with colibricore, find words ending in one of
          ``self.suffixes`` whose variants with *every* other suffix are also
          in the frequency list, store them in ``self.confusibles`` and write
          them to ``modelfile``, one per line.
        * ``self.modelfile``: train a Timbl classifier. For every corpus
          window whose middle token is in ``self.confusibles``, the features
          are left context + normalised form + right context and the class is
          the suffix, both obtained from ``self.getsuffix``. This phase
          expects ``self.confusibles`` to be set already.

        Any other ``modelfile`` value is silently ignored.

        Args:
            sourcefile: Path of the training corpus (optionally ``.bz2`` or
                ``.gz`` compressed for the classifier phase).
            modelfile: Path of the model to produce; selects the phase.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if modelfile == self.confusiblefile:
            #Build frequency list
            self.log("Preparing to generate lexicon for suffix confusible module")
            classfile = stripsourceextensions(sourcefile) +  ".cls"
            corpusfile = stripsourceextensions(sourcefile) +  ".dat"

            # The class file and encoded corpus are reused when already on disk.
            if not os.path.exists(classfile):
                self.log("Building class file")
                classencoder = colibricore.ClassEncoder("", self.settings['minlength'], self.settings['maxlength']) #character length constraints
                classencoder.build(sourcefile)
                classencoder.save(classfile)
            else:
                classencoder = colibricore.ClassEncoder(classfile, self.settings['minlength'], self.settings['maxlength'])

            if not os.path.exists(corpusfile):
                self.log("Encoding corpus")
                classencoder.encodefile( sourcefile, corpusfile)

            self.log("Generating frequency list")
            options = colibricore.PatternModelOptions(mintokens=self.settings['freqthreshold'],minlength=1,maxlength=1) #unigrams only
            model = colibricore.UnindexedPatternModel()
            model.train(corpusfile, options)

            self.log("Finding confusible pairs")
            classdecoder = colibricore.ClassDecoder(classfile)
            self.confusibles = [] #pylint: disable=attribute-defined-outside-init
            # Mirrors self.confusibles for O(1) "already listed?" checks; the
            # list itself keeps the output order.
            listed = set()
            for pattern in model:
                try:
                    pattern_s = pattern.tostring(classdecoder)
                except UnicodeDecodeError:
                    self.log("WARNING: Unable to decode a pattern in the model!!! Invalid utf-8!")
                    # Skip the pattern: without this, pattern_s would be
                    # unbound (first iteration) or left over from the
                    # previous pattern.
                    continue
                for suffix in self.suffixes:
                    if not pattern_s.endswith(suffix) or pattern_s in listed:
                        continue
                    # Collect the variants with every other suffix; a single
                    # missing or rejected variant discards the whole group.
                    found = []
                    for othersuffix in self.suffixes:
                        if othersuffix == suffix:
                            continue
                        otherpattern_s = pattern_s[:-len(suffix)] + othersuffix
                        try:
                            otherpattern = classencoder.buildpattern(otherpattern_s,False,False)
                        except KeyError:
                            # Variant is not encodable with this class file.
                            found = []
                            break
                        if otherpattern not in model:
                            found = []
                            break
                        if self.settings['maxratio'] != 0:
                            freqs = (model.occurrencecount(pattern), model.occurrencecount(otherpattern))
                            ratio = max(freqs) / min(freqs)
                            # NOTE(review): groups are rejected when the
                            # frequency ratio is *below* 'maxratio'; confirm
                            # this direction matches the setting's intent.
                            if ratio < self.settings['maxratio']:
                                found = []
                                break
                        found.append(otherpattern_s)
                    if found:
                        self.confusibles.append(pattern_s)
                        self.confusibles.extend(found)
                        listed.add(pattern_s)
                        listed.update(found)

            self.log("Writing confusible list")
            with open(modelfile,'w',encoding='utf-8') as f:
                for confusible in self.confusibles:
                    f.write(confusible + "\n")

        elif modelfile == self.modelfile:
            if self.hapaxer:
                self.log("Training hapaxer...")
                self.hapaxer.train()

            l = self.settings['leftcontext']
            r = self.settings['rightcontext']
            n = l + 1 + r

            self.log("Generating training instances...")
            fileprefix = modelfile.replace(".ibase","") #has been verified earlier
            classifier = TimblClassifier(fileprefix, self.gettimbloptions())
            if sourcefile.endswith(".bz2"):
                iomodule = bz2
            elif sourcefile.endswith(".gz"):
                iomodule = gzip
            else:
                iomodule = io
            # Set copy for O(1) membership tests in the per-window loop below.
            confusibles = set(self.confusibles)
            with iomodule.open(sourcefile,mode='rt',encoding='utf-8', errors='ignore') as f:
                for i, line in enumerate(f):
                    if i % 100000 == 0:
                        # Progress report once per 100000 lines (previously
                        # emitted once per n-gram of such a line).
                        print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + " - " + str(i),file=sys.stderr)
                    for ngram in Windower(line, n):
                        # The focus word is read before the hapaxer touches
                        # the window.
                        confusible = ngram[l]
                        if confusible not in confusibles:
                            continue
                        if self.hapaxer:
                            ngram = self.hapaxer(ngram)
                        leftcontext = tuple(ngram[:l])
                        rightcontext = tuple(ngram[l+1:])
                        suffix, normalized = self.getsuffix(confusible)
                        if suffix is not None:
                            classifier.append( leftcontext + (normalized,) + rightcontext , suffix )

            self.log("Training classifier...")
            classifier.train()

            self.log("Saving model " + modelfile)
            classifier.save()
Exemplo n.º 9
0
    def train(self, sourcefile, modelfile, **parameters):
        """Train either the Timbl classifier or the colibricore pattern model.

        The hapaxer, when configured, is trained first in both cases. The
        model type is selected by the extension of ``modelfile``:

        * ``.ibase``: every corpus line is cut into windows of
          ``leftcontext + 1 + rightcontext`` tokens by ``Windower`` (passed
          through the hapaxer when present). The middle token is the class and
          the surrounding tokens are the features. Windows whose middle token
          equals the hapaxer placeholder are skipped.
        * ``.patternmodel``: a colibricore class file and encoded corpus are
          created next to the source (unless already present), a
          ``<modelfile>.cls`` symlink to the class file is made, and an
          unindexed pattern model is trained and written to ``modelfile``.

        Any other extension only trains the hapaxer.

        Args:
            sourcefile: Path of the training corpus (optionally ``.bz2`` or
                ``.gz`` compressed for the ``.ibase`` case; read as UTF-8).
            modelfile: Path of the model to produce.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if self.hapaxer:
            self.log("Training hapaxer...")
            self.hapaxer.train()
        if modelfile.endswith('.ibase'):
            l = self.settings['leftcontext']
            r = self.settings['rightcontext']
            n = l + 1 + r

            self.log("Generating training instances...")
            fileprefix = modelfile.replace(".ibase","") #has been verified earlier
            classifier = TimblClassifier(fileprefix, self.gettimbloptions())
            if sourcefile.endswith(".bz2"):
                iomodule = bz2
            elif sourcefile.endswith(".gz"):
                iomodule = gzip
            else:
                iomodule = io
            with iomodule.open(sourcefile,mode='rt',encoding='utf-8') as f:
                for i, line in enumerate(f):
                    if i % 100000 == 0:
                        # Progress report on stderr every 100000 lines.
                        print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + " - " + str(i),file=sys.stderr)
                    for ngram in Windower(line, n):
                        if self.hapaxer:
                            ngram = self.hapaxer(ngram)
                        focus = ngram[l]
                        # Do not learn the hapaxer placeholder as a class.
                        if self.hapaxer and focus == self.hapaxer.placeholder:
                            continue
                        leftcontext = tuple(ngram[:l])
                        rightcontext = tuple(ngram[l+1:])
                        classifier.append( leftcontext + rightcontext , focus )

            self.log("Training classifier...")
            classifier.train()

            self.log("Saving model " + modelfile)
            classifier.save()
        elif modelfile.endswith('.patternmodel'):
            self.log("Preparing to generate lexicon for Language Model")
            classfile = stripsourceextensions(sourcefile) +  ".cls"
            corpusfile = stripsourceextensions(sourcefile) +  ".dat"

            # The class file and encoded corpus are reused when already on disk.
            if not os.path.exists(classfile):
                self.log("Building class file")
                classencoder = colibricore.ClassEncoder()
                classencoder.build(sourcefile)
                classencoder.save(classfile)
            else:
                classencoder = colibricore.ClassEncoder(classfile)

            # Make a symlink to the class file, using the model name instead
            # of the source name. (This check used to be repeated verbatim
            # after the corpus encoding step; once is enough.)
            if not os.path.exists(modelfile+'.cls'):
                os.symlink(classfile, modelfile + '.cls')

            if not os.path.exists(corpusfile):
                self.log("Encoding corpus")
                classencoder.encodefile( sourcefile, corpusfile)

            self.log("Generating pattern model")
            options = colibricore.PatternModelOptions(mintokens=self.settings['freqthreshold'],minlength=1,maxlength=1)
            model = colibricore.UnindexedPatternModel()
            model.train(corpusfile, options)

            self.log("Saving model " + modelfile)
            model.write(modelfile)
Exemplo n.º 10
0
from timbl import TimblClassifier

# Minimal TimblClassifier demo: train on two instances, then classify a third.
classifier = TimblClassifier('test', '-a 0 -k 1 +vk')

# (features, class label) pairs used as training data.
training_data = (
    (('dit', 'is', 'een'), 'idee'),
    (('dat', 'was', 'geen'), 'doen'),
)
for features, label in training_data:
    classifier.append(features, label)

classifier.train()

# Classify an unseen feature vector and show whatever classify() returns.
r = classifier.classify(('dit', 'was', 'geen'))
print(r)
Exemplo n.º 11
0
    def train(self, sourcefile, modelfile, **parameters):
        """Train either the Timbl classifier or the colibricore pattern model.

        The hapaxer, when configured, is trained first in both cases. The
        model type is selected by the extension of ``modelfile``:

        * ``.ibase``: every corpus line is cut into windows of
          ``leftcontext + 1 + rightcontext`` tokens by ``Windower`` (passed
          through the hapaxer when present). The middle token is the class and
          the surrounding tokens are the features. Windows whose middle token
          equals the hapaxer placeholder are skipped.
        * ``.patternmodel``: a colibricore class file and encoded corpus are
          created next to the source (unless already present), a
          ``<modelfile>.cls`` symlink to the class file is made, and an
          unindexed pattern model is trained and written to ``modelfile``.

        Any other extension only trains the hapaxer.

        Args:
            sourcefile: Path of the training corpus (optionally ``.bz2`` or
                ``.gz`` compressed for the ``.ibase`` case; read as UTF-8).
            modelfile: Path of the model to produce.
            **parameters: Accepted for interface compatibility; not used here.
        """
        if self.hapaxer:
            self.log("Training hapaxer...")
            self.hapaxer.train()
        if modelfile.endswith('.ibase'):
            l = self.settings['leftcontext']
            r = self.settings['rightcontext']
            n = l + 1 + r

            self.log("Generating training instances...")
            fileprefix = modelfile.replace(".ibase","") #has been verified earlier
            classifier = TimblClassifier(fileprefix, self.gettimbloptions())
            if sourcefile.endswith(".bz2"):
                iomodule = bz2
            elif sourcefile.endswith(".gz"):
                iomodule = gzip
            else:
                iomodule = io
            with iomodule.open(sourcefile,mode='rt',encoding='utf-8') as f:
                for i, line in enumerate(f):
                    if i % 100000 == 0:
                        # Progress report on stderr every 100000 lines.
                        print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + " - " + str(i),file=sys.stderr)
                    for ngram in Windower(line, n):
                        if self.hapaxer:
                            ngram = self.hapaxer(ngram)
                        focus = ngram[l]
                        # Do not learn the hapaxer placeholder as a class.
                        if self.hapaxer and focus == self.hapaxer.placeholder:
                            continue
                        leftcontext = tuple(ngram[:l])
                        rightcontext = tuple(ngram[l+1:])
                        classifier.append( leftcontext + rightcontext , focus )

            self.log("Training classifier...")
            classifier.train()

            self.log("Saving model " + modelfile)
            classifier.save()
        elif modelfile.endswith('.patternmodel'):
            self.log("Preparing to generate lexicon for Language Model")
            classfile = stripsourceextensions(sourcefile) +  ".cls"
            corpusfile = stripsourceextensions(sourcefile) +  ".dat"

            # The class file and encoded corpus are reused when already on disk.
            if not os.path.exists(classfile):
                self.log("Building class file")
                classencoder = colibricore.ClassEncoder()
                classencoder.build(sourcefile)
                classencoder.save(classfile)
            else:
                classencoder = colibricore.ClassEncoder(classfile)

            # Make a symlink to the class file, using the model name instead
            # of the source name. (This check used to be repeated verbatim
            # after the corpus encoding step; once is enough.)
            if not os.path.exists(modelfile+'.cls'):
                os.symlink(classfile, modelfile + '.cls')

            if not os.path.exists(corpusfile):
                self.log("Encoding corpus")
                classencoder.encodefile( sourcefile, corpusfile)

            self.log("Generating pattern model")
            options = colibricore.PatternModelOptions(mintokens=self.settings['freqthreshold'],minlength=1,maxlength=1)
            model = colibricore.UnindexedPatternModel()
            model.train(corpusfile, options)

            self.log("Saving model " + modelfile)
            model.write(modelfile)
Exemplo n.º 12
0
class skTiMBL(BaseEstimator, ClassifierMixin):
    """scikit-learn style estimator that delegates to a ``TimblClassifier``.

    Parameters:
        prefix: File prefix handed to ``TimblClassifier``.
        algorithm: Value for Timbl's ``-a`` option.
        dist_metric: Stored but not used by this class yet.
        k: Value for Timbl's ``-k`` option (number of neighbours).
        normalize: Forwarded to ``TimblClassifier`` as ``normalize``.
        debug: When truthy, print a note when the sparse code path is taken.
        flushdir: Forwarded to ``TimblClassifier`` as ``flushdir``.
    """

    def __init__(self, prefix='timbl', algorithm=4, dist_metric=None,
                 k=1,  normalize=False, debug=0, flushdir=None):
        self.prefix = prefix
        self.algorithm = algorithm
        self.dist_metric = dist_metric
        self.k = k
        self.normalize = normalize
        self.debug = debug
        self.flushdir = flushdir


    def _make_timbl_options(self, *options):
        """Placeholder for building the Timbl option string.

        Options this is meant to cover eventually:

        -a algorithm
        -m metric
        -w weighting
        -k amount of neighbours
        -d class voting weights
        -L frequency threshold
        -T which feature index is label
        -N max number of features
        -H turn hashing on/off

        Not implemented yet: ``fit()`` currently builds the option string
        itself and this method returns ``None``.
        """
        pass


    @staticmethod
    def _sparse_features(row):
        """Format one CSR row as Timbl sparse features ``"(index,value)"``.

        Column indices are shifted by one, so the emitted indices are 1-based.
        """
        return ['({},{})'.format(index + 1, value)
                for index, value in zip(row.indices, row.data)]


    def fit(self, X, y):
        """Feed all rows of ``X`` with labels ``y`` to Timbl and train it.

        ``X`` may be dense or CSR-sparse; it is validated and cast to int64 by
        ``check_X_y``. Labels are handed to Timbl as strings, except that
        object-dtype labels on the dense path are passed through unchanged.

        Returns:
            ``self``, following the scikit-learn convention.
        """
        X, y = check_X_y(X, y, dtype=np.int64, accept_sparse='csr')

        n_rows = X.shape[0]
        self.classes_ = np.unique(y)
        timbl_options = "-a{} -k{} -N{} -vf".format(self.algorithm, self.k, X.shape[1])

        if sp.sparse.issparse(X):
            if self.debug: print('Features are sparse, choosing faster learning')

            self.classifier = TimblClassifier(self.prefix, timbl_options,
                                              format='Sparse', debug=True, sklearn=True, flushdir=self.flushdir,
                                              flushthreshold=20000, normalize=self.normalize)

            for i in range(n_rows):
                self.classifier.append(self._sparse_features(X[i]), str(y[i]))

        else:

            self.classifier = TimblClassifier(self.prefix, timbl_options,
                                              debug=True, sklearn=True, flushdir=self.flushdir, flushthreshold=20000,
                                              normalize=self.normalize)

            if y.dtype != 'O':
                y = y.astype(str)

            for i in range(n_rows):
                # X is a plain ndarray on this path; rows have no .toarray(),
                # so the row is converted to a list directly.
                self.classifier.append(list(X[i]), y[i])

        self.classifier.train()
        return self


    def _timbl_predictions(self, X, part_index, y=None):
        """Classify every row of ``X`` and collect one part of each result.

        Args:
            X: Dense or CSR-sparse samples; validated and cast to float64.
            part_index: ``0`` collects the predicted label as ``np.int64``,
                ``1`` collects the distance as a one-element ``[float]`` list.
                Any other value raises ``KeyError``.
            y: Ignored.

        Returns:
            ``np.ndarray`` built from the collected values.
        """
        # NOTE(review): fit() casts X to int64 while this method casts to
        # float64, so sparse feature values are formatted differently at
        # train and predict time (e.g. "1" vs "1.0"); confirm Timbl treats
        # them as equal.
        converters = {0 : lambda label, distance : np.int64(label),
                      # np.float was removed from NumPy; the builtin float
                      # is what it used to alias.
                      1 : lambda label, distance : [float(distance)],
                     }
        X = check_array(X, dtype=np.float64, accept_sparse='csr')

        n_samples = X.shape[0]

        pred = []
        convert = converters[part_index]
        if sp.sparse.issparse(X):
            if self.debug: print('Features are sparse, choosing faster predictions')

            for i in range(n_samples):
                label, proba, distance = self.classifier.classify(self._sparse_features(X[i]))
                pred.append(convert(label, distance))

        else:
            for i in range(n_samples):
                # Dense rows are ndarrays and have no .toarray().
                label, proba, distance = self.classifier.classify(list(X[i]))
                pred.append(convert(label, distance))

        return np.array(pred)



    def predict(self, X, y=None):
        """Return the predicted label (as ``np.int64``) for every row of ``X``."""
        return self._timbl_predictions(X, part_index=0)


    def predict_proba(self, X, y=None):
        """Return the same output as ``predict``.

        TIMBL is a discrete classifier. It cannot give probability estimations.
        To ensure that scikit-learn functions with TIMBL (and especially metrics
        such as ROC_AUC), this method is implemented.

        For ROC_AUC, the classifier corresponds to a single point in ROC space,
        instead of a probabilistic continuum such as classifiers that can give
        a probability estimation (e.g. Linear classifiers). For an explanation,
        see Fawcett (2005).
        """
        # Must go through self: a bare predict(X) is an undefined name here.
        return self.predict(X)


    def decision_function(self, X, y=None):
        """Return the distance reported by Timbl for every row of ``X``.

        The decision function is interpreted here as being the distance between
        the instance that is being classified and the nearest point in k space.
        """
        return self._timbl_predictions(X, part_index=1)