Exemplo n.º 1
0
    def _run(self):
        """Greedy forward attribute selection followed by TAN learning.

        Each round scores every attribute still in ``attrs_left`` by calling
        ``cross_validate`` (score type ``'WC'``, 10 runs) on the data subset
        made of the current selection, the candidate attribute and
        ``cls_node``.  The candidate that raises ``self.max_score`` the most
        is moved into the selection.  The search ends when every attribute
        has been selected, when ``self._stop()`` returns true, or when no
        candidate beats the best score recorded so far.

        Side effects: sets ``self.max_score``, ``self.attrs_selected`` (the
        sorted selected indices followed by ``cls_node``), ``self.network``,
        ``self.cpd`` and ``self._learner``, and registers the learned network
        with ``self.result``.
        """
        self.max_score = -1
        attrs_selected = self.attrs_selected = []
        # Must be a real list: entries are popped below, and on Python 3 a
        # ``range`` object has no ``pop`` method.
        attrs_left = list(range(self.num_attr))
        # Index num_attr, i.e. the one right after the attribute indices;
        # presumably the class variable's column -- confirm against the data layout.
        cls_node = self.num_attr

        _stop = self._stop
        while len(attrs_selected) < self.num_attr and not _stop():
            pick = -1
            for i, a in enumerate(attrs_left):
                candidate = sorted(attrs_selected + [a, cls_node])
                score = cross_validate(self.data.subset(candidate),
                                       score_type='WC',
                                       runs=10)
                # max_score persists across rounds, so a candidate is only
                # accepted if it beats the best score of the whole search.
                if score > self.max_score:
                    self.max_score = score
                    pick = i
            if pick == -1:
                # No remaining attribute improved the score: stop searching.
                break
            attrs_selected.append(attrs_left.pop(pick))

        # attrs_selected is the same list object as self.attrs_selected, so
        # the sort and append below are visible through the attribute too.
        attrs_selected.sort()
        attrs_selected.append(cls_node)
        tan_learner = TANClassifierLearner(
            self.data.subset(attrs_selected),
            local_cpd_cache=SharedLocalCPDCache(self._cpd_cache._cpd_cache,
                                                attrs_selected))
        tan_learner.run()
        self.network = tan_learner.network
        self.cpd = tan_learner.cpd
        self._learner = tan_learner
        self.result.add_network(self.network, self.max_score)
Exemplo n.º 2
0
 def _crossValidateScoreFunc(self, subset_idx, **cvargs):
     """Return the cross-validation result for a subset of ``self.data``.

     ``subset_idx`` is handed to ``self.data.subset``; the resulting data is
     passed to ``cross_validate`` together with ``self.classifier_type`` and
     any extra keyword arguments in ``cvargs``.
     """
     return cross_validate(self.data.subset(subset_idx),
                           classifier_type=self.classifier_type,
                           **cvargs)
Exemplo n.º 3
0
parser.set_defaults(classifier_type="tan")

parser.add_option("-r", dest="test_ratio", type="float", help="Ratio of test data")
parser.set_defaults(test_ratio=0.3)

parser.add_option("-t", dest="runs", type="int", help="Number of runs")
parser.set_defaults(runs=10)

parser.add_option("-d", dest="numbins", type="int", help="Number of bins in discretizing")
parser.set_defaults(numbins=0)

parser.add_option("-v", dest="verbose", action="store_true", help="Report verbosely")
parser.set_defaults(verbose=False)

parser.add_option("-s", dest="score_type", help="Score type")
parser.set_defaults(score_type="BA")

(options, args) = parser.parse_args()

# The data file is the one required positional argument.
if not args:
    # Single-argument print(...) calls produce the same output under the
    # Python 2 print statement and the Python 3 print function.
    print("You didn't specify a data file")
    print("")
    sys.exit(1)

datafile = args[0]
dataset = data.fromfile(datafile)
if options.numbins:
    # numbins == 0 (the default) skips discretization entirely.  The variable
    # with the highest index is excluded from binning -- presumably the class
    # column; confirm against the data file format.
    dataset.discretize(numbins=options.numbins, excludevars=[dataset.variables.size-1])
cross_validate(dataset, options.classifier_type, options.test_ratio, options.runs, options.verbose, options.score_type)