Пример #1
0
Файл: ss.py Проект: awd4/spnss
    def __init__(self, net, trn, vld, threshold):
        ''' Set up the search state around an existing network.

        net       -- the network to search over.
                     IMPORTANT: we assume that 'net' has already been trained using 'trn'.
        trn       -- data with exactly two dimensions (checked via 'trn.ndim'); passed to SSData.
        vld       -- passed to NetHistory (presumably the validation data -- confirm against NetHistory).
        threshold -- stored unchanged as 'self.thresh'.
        '''
        assert trn.ndim == 2

        self.net = net

        # The two helpers are built in this order on purpose: SSData first, NetHistory second.
        self.ssdata = SSData(net, trn)
        self.net_history = NetHistory(net, vld)

        self.thresh = threshold
Пример #2
0
Файл: ss.py Проект: awd4/spnss
class SS:
    ''' Structure search over an SPN graph.

    Holds the network being modified ('net'), per-edge search data ('ssdata'), a history of
    the network used to compare and roll back candidates ('net_history'), and the threshold
    handed to select_step() ('thresh').
    '''

    def __init__(self, net, trn, vld, threshold):
        ''' Set up the search state around an existing network.

        net       -- the network to search over.
                     IMPORTANT: we assume that 'net' has already been trained using 'trn'.
        trn       -- data with exactly two dimensions (checked via 'trn.ndim'); passed to SSData.
        vld       -- passed to NetHistory (presumably the validation data -- confirm against NetHistory).
        threshold -- stored unchanged as 'self.thresh'.
        '''
        assert trn.ndim == 2
        self.net         = net
        self.ssdata      = SSData(net, trn)

        self.net_history = NetHistory(net, vld)

        self.thresh      = threshold

    def print_stats(self, i=None, name=''):
        ''' Log (at WARNING level) the network size, latest validation value and threshold.

        i    -- optional iteration number; when given, a banner line with 'i' and 'name' is logged first.
        name -- optional label shown in the banner line.
        '''
        if i is not None:
            logging.warning('\t================ %d =============== %s', i, name)
        # NOTE(review): this reads 'vld_hist' whereas skip_search() reads 'vh' on the same
        # object -- confirm that NetHistory really exposes both names.
        logging.warning('\tsize: %d\tvld: %f\tthresh: %f',
                        len(self.net.pot), self.net_history.vld_hist[-1], self.thresh)

    def step(self):
        ''' Take a step in the search space of SPN graphs.

        Returns True when a step was taken and False when select_step() had nothing to offer.
        '''
        pn, chll = select_step(self.ssdata, self.thresh)
        if (pn, chll) == (None, None):
            return False
        mol = [ops.MixOp(self.net, pn, chl) for chl in chll]

        for mo in mol:
            # cluster
            # np.hstack() needs a real sequence; passing a generator is deprecated/rejected by NumPy.
            data = np.hstack([self.ssdata.edges[c][:, np.newaxis] for c in mo.chl])
            assert len(data) >= knobs.min_instances
            nvals = [len(c.weights) if c.is_sum() else len(c.masses) for c in mo.chl]
            qa, _ = nbc.kcluster(data, nvals)
            #qa, nc = nbc.inc_hard_em(data, nvals)

            # change the graph
            mo.connect(qa.max() + 1)
            ops.compute_params(mo, qa, self.ssdata.edges, knobs.laplace_smooth)
            ops.adjust_edges(mo, qa, self.ssdata.edges)
            self.ssdata.update_scores_of(list(mo.prod_nodes()))

        self.net_history.add_op_node_list(mol)

        logging.info('\tncl: %s', [len(mo.pnl) for mo in mol])

        return True

    def step_ahead(self, num_steps):
        ''' Take up to 'num_steps' steps, then call net_history.save_vld().

        Stops early (with a warning) as soon as step() reports that no step is available.
        Returns the number of steps actually taken, which is 0 when 'num_steps' is 0 or
        when the very first step() fails.
        '''
        taken = 0
        while taken < num_steps:
            if not self.step():
                logging.warning('\tran out of steps to take.')
                break
            taken += 1
        self.net_history.save_vld()
        return taken

    def _move_to_best(self, lo, hi):
        ''' Move the history to the best net index reported for (lo, hi) and return that net. '''
        nh = self.net_history
        nh.move_to(nh.best_net_index(lo, hi))
        assert self.net == nh.net
        return nh.net

    def skip_search(self, num_steps):
        ''' This is the main structure search algorithm.

        Repeatedly advances 'num_steps' steps at a time, keeping a window of three history
        indices (i0 < i1 < i2). The search stops once the value at the newest index is no
        better than the one before it (nh.vh[i2] <= nh.vh[i1]); the history is then moved to
        the best net index between i0 and i2 and that net is returned.

        num_steps -- number of steps per look-ahead; must be >= 1.
        Raises Exception if the loop runs for an unexpectedly huge number of iterations.
        '''
        assert num_steps >= 1

        nh = self.net_history

        i0 = 0
        i1 = i0 + self.step_ahead(num_steps)
        i2 = i1 + self.step_ahead(num_steps)
        if nh.vh[i1] <= nh.vh[i0]:
            # The very first look-ahead did not help: settle within (i0, i1).
            return self._move_to_best(i0, i1)

        MAX_STEPS = 100000000
        i = 0
        while True:
            if i >= MAX_STEPS:
                raise Exception('skip_search() just took a HUGE number of steps; that is not expected.')
            self.print_stats(i)
            logging.info('\t%f', nh.vh[-1])
            if nh.vh[i2] <= nh.vh[i1]:
                break
            taken = self.step_ahead(num_steps)
            # The oldest entry leaves the window here and is never read by this method again.
            nh.vh[i0] = None
            i0 = i1
            i1 = i2
            i2 += taken
            i += 1

        return self._move_to_best(i0, i2)