class Test_ObjFramework_2_records_same_cls(unittest.TestCase):
    """Two objects of the same class are both returned by ``query``."""

    def setUp(self):
        self.of = ObjFactory()
        self.obj1= self.of.new(GenericBase,
                               'Student',
                               objid='booker',
                               modname=__name__,
                               name='booker',
                               age=23)

        self.obj2= self.of.new(GenericBase,
                               'Student',
                               objid='frank',
                               modname=__name__,
                               name='frank',
                               age=19)

    def tearDown(self):
        self.of.reset()

    def test_2records_same_class(self):
        # query() order is not part of the contract; compare sorted names
        names = sorted(obj.name for obj in self.of.query('Student'))
        # assertEqual: assertEquals is a deprecated alias (removed in 3.12)
        self.assertEqual(names,['booker','frank'])
class Test_ObjFramework_Database(unittest.TestCase):
    """A ``dbtblgeneric`` object built from datamembers is stored and queryable."""

    def setUp(self):
        self.of = ObjFactory(True)
        self.database = Database('foobar')
        self.foobar= self.of.new(dbtblgeneric,
                                 'DBLesson',
                                 objid='dblesson0',
                                 constructor='datamembers',
                                 modname=__name__,
                                 database=self.database,
                                 dm={'student':'booker',
                                     'period':2,
                                     'dow':3})

    def tearDown(self):
        self.of.reset()

    def test_num_obj_created(self):
        self.assertEqual(len(self.of.query('DBLesson')),1)

    def test_correct_keys_created(self):
        self.assertTrue(self.of.object_exists('DBLesson','dblesson0'))

    def test_objects_created_stored(self):
        _lesson = self.of.object_get('DBLesson','dblesson0')
        self.assertEqual(_lesson.__class__.__name__,"DBLesson")

    def test_objects_have_attributes(self):
        # each datamember becomes an attribute on the created object
        _lesson = self.of.object_get('DBLesson','dblesson0')
        self.assertEqual(_lesson.student,'booker')
        self.assertEqual(_lesson.period,2)
        self.assertEqual(_lesson.dow,3)
class Test_ObjFrameworkIter(unittest.TestCase):
    """``object_iter`` yields every stored object across all classes."""

    def setUp(self):
        self.of = ObjFactory(True)
        # two Students and one Classroom, created in this order
        for clsname, objid in (('Student', 'booker'),
                               ('Student', 'fred'),
                               ('Classroom', '1a')):
            self.of.new(GenericBase, clsname, objid=objid, modname=__name__)

    def tearDown(self):
        self.of.reset()

    def test_iter(self):
        # iteration order is not asserted, only the set of objids
        objids = sorted(obj.objid for obj in self.of.object_iter())
        self.assertListEqual(objids,['1a','booker','fred'])
class Test_ObjFrameworkBasic(unittest.TestCase):
    """One created object is visible through a fresh ``ObjFactory().store``.

    The tests read through a new ``ObjFactory()`` rather than ``self.of``,
    so they rely on the factory sharing its store between instances.
    """

    def setUp(self):
        self.of = ObjFactory()

        self.of.new(GenericBase,
                    "Student",
                    objid='booker',
                    modname=__name__,
                    name='booker',
                    age=23)

    def tearDown(self):
        self.of.reset()

    def test_num_obj_created(self):
        self.assertEqual(len(ObjFactory().store['Student']),1)

    def test_correct_keys_created(self):
        self.assertTrue(ObjFactory().store['Student'].has_key('booker'))

    def test_objects_created_stored(self):
        _student = ObjFactory().store['Student']['booker']
        self.assertEqual(_student.__class__.__name__,"Student")

    def test_objects_have_attributes(self):
        _student = ObjFactory().store['Student']['booker']
        self.assertEqual(_student.name,'booker')
        self.assertEqual(_student.age,23)
class Test_ObjFrameworkDumpNested(unittest.TestCase):
    """``query_advanced`` over Student objects with several attributes."""

    def setUp(self):
        self.of = ObjFactory(True)
        self.obj1 = self.of.new(GenericBase,
                    'Student',
                    objid='booker',
                    nationality='british',
                    modname=__name__)

        self.of.new(GenericBase,
                    'Student',
                    objid='fred',
                    age=23,
                    nationality='british',
                    modname=__name__)

        # same objid as above; kept to exercise duplicate-objid handling
        self.of.new(GenericBase,
                    'Student',
                    objid='fred',
                    age=35,
                    nationality='irish',
                    modname=__name__)

        self.of.new(GenericBase,
                    'Classroom',
                    objid='1a',
                    nationality='swedish',
                    modname=__name__)

    def tearDown(self):
        self.of.reset()

    # These two tests were previously disabled by wrapping them in a string
    # literal; skip them explicitly so they are visible in test reports.
    @unittest.skip("disabled in original source; reason not recorded")
    def test_1clause(self):
        results = self.of.query_advanced('Student',[('objid','booker')])

        self.assertEqual(len(results),1)
        self.assertEqual(results[0].objid,'booker')

    @unittest.skip("disabled in original source; reason not recorded")
    def test_2clause(self):
        results = self.of.query_advanced('Student',[('nationality','british'),
                                                    ('objid','fred')])

        self.assertEqual(len(results),1)
        self.assertEqual(results[0].age,23)

    def test_update_then_search(self):
        ''' make sure that search picks up the updated version of the object '''

        self.obj1.nationality = 'indian'
        results = self.of.query_advanced('Student',[('objid','booker')])

        self.assertEqual(results[0].nationality,'indian')
class Test_ObjFrameworkDupeID(unittest.TestCase):
    """Creating a second object with an existing objid yields an equal object."""

    def setUp(self):
        self.of = ObjFactory()
        # the same Student/objid requested twice, one call after the other
        self.obj1, self.obj2 = [self.of.new(GenericBase,
                                            'Student',
                                            objid='booker',
                                            modname=__name__)
                                for _ in range(2)]

    def tearDown(self):
        self.of.reset()

    def test_num_dupe_objid(self):
        self.assertEqual(self.obj1,self.obj2)
class Test_ObjFramework_2_class(unittest.TestCase):
    """``query()`` with no argument lists the stored class names."""

    def setUp(self):
        self.of = ObjFactory()
        self.obj1= self.of.new(GenericBase,
                               'Student',
                               objid='booker',
                               modname=__name__,
                               name='booker',
                               age=23)

        self.obj2= self.of.new(GenericBase,
                               'Subject',
                               objid='science',
                               modname=__name__,
                               name='science',
                               teacher_name='fran')

    def tearDown(self):
        self.of.reset()

    def test_2_class(self):
        # the order of class names is not guaranteed (it may come from
        # dict keys), so compare sorted to keep the test deterministic
        self.assertListEqual(sorted(self.of.query()),['Student','Subject'])
class SSViewer(object):
    """Load lessons from a database into an ObjFactory and lay them out as a
    2-D grid for display.

    Cell colours come from ``colorpalette`` (background) and ``fontpalette``
    (foreground), both keyed by cell text.
    """

    def __init__(self,dbname,refdbname):
        """Set up the colour palettes, the object factory and both databases.

        :param dbname: database holding the ``lesson`` table read by ``load``.
        :param refdbname: database passed to ``sswizard_utils.setenums``.
        """
        log.log(thisfuncname(),3,msg="initialize",dbname=dbname,refdbname=refdbname)

        # background colour per cell value
        self.colorpalette = dict(wp=green,subject=lightblue,ap='yellow',
                                 Movement=pink,ELA=salmon,Humanities=lightyellow,
                                 Counseling=lightgreen,Math=lightturquoise,
                                 Music=lightblue,STEM=lavender,Art=purple,History=pink,
                                 Science=darkgreen,Core=karky,Chess=burgundy,
                                 computertime='darkgrey',Speech=darkburgundy,
                                 Student_News=darkgrey,Computer_Time=brown,
                                 Activity_Period=mauve,Melissa=navyblue,Amelia=darkgreen,
                                 Samantha=darkyellow, Alexa=paleblue, Paraic=palegreen,
                                 Francisco=cerise,Rahul=verydarkgrey,Dylan=verydarkgrey,
                                 Moira=verydarkgrey,Issey=verydarkgrey, Daryl=verydarkgrey,
                                 Karolina=verydarkgrey)

        # foreground (font) colour per cell value
        self.fontpalette = dict(Amelia=green,Paraic=darkgreen,Stan=lavender,
                                Samantha=lightgreen,Alexa=blue,Francisco=purple,
                                Melissa=lightblue,Rahul=dirtyyellow,Dylan=dirtyyellow,
                                Moira=dirtyyellow,Issey=dirtyyellow, Daryl=dirtyyellow,
                                Karolina=dirtyyellow,Chess=pink,Student_News=lightyellow,
                                subject=blue)

        self.of = ObjFactory(True)

        self.refdatabase = Database(refdbname)
        self.dbname = dbname
        self.database = Database(self.dbname)
        self.lastsaveversion=0

    def _color_get_multi(self,values):
        """Return ``(bgs, fgs)``: parallel colour lists, one entry per value."""
        bgs=[]
        fgs=[]
        for value in values:
            _bg,_fg = self.color_get(value)
            bgs.append(_bg)
            fgs.append(_fg)
        return(bgs,fgs)

    def color_get(self,value):
        """Return ``(bgcolor, fgcolor)`` for one cell value.

        Numeric values are converted to strings. Spaces become underscores so
        values such as ``"Student News"`` match keys like ``Student_News``.
        A value with exactly one ``[`` and one ``]`` defaults to a red
        background. Only the text before the first ``.`` is used for palette
        lookup, and a palette match overrides the defaults (lightgrey/black).
        """
        bg = lightgrey
        fg = black

        try:
            int(value)
            value = str(value)
        except ValueError:
            pass

        value = value.replace(" ","_")

        if value.count("[") == 1 and value.count("]") == 1:
            bg = red

        value = value.split(".")[0]

        bg = self.colorpalette.get(value,bg)
        fg = self.fontpalette.get(value,fg)

        return(bg,fg)

    def viewer(self,yaxis_type,xaxis_type,ztypes, source_type,source_value,
               conflicts_only='N',constraints=None,wratio=None,formatson=False):
        """Build the display grid of lessons.

        Row 0 is ``['']`` followed by one column label per distinct ``name`` of
        the ``yaxis_type`` objects, in factory order. Every following row
        starts with an ``xaxis_type`` name, and rows are ordered by that name's
        value in ``self.enums[xaxis_type]['name2enum']`` (set up by ``load``).
        The row then holds one cell per column.

        A cell is built from ``source_obj.lessons[yval][xval]`` for each source
        object. If ``ztypes == ['*']``, the cell is ``[n]``, where n is the
        number of matching lessons. Otherwise it is a list of distinct tuples
        holding the ``.name`` of each ``ztypes`` attribute present on a lesson.

        ``constraints`` is an optional list of ``(attr, name)`` pairs. A lesson
        is kept only when ``getattr(lesson, attr).name == name`` for all pairs.

        :param source_value: ``""`` uses every ``source_type`` object;
            otherwise only the object with that objid is used.
        :param conflicts_only: accepted for compatibility; currently unused.
        :param wratio: accepted for compatibility; currently unused.
        :param formatson: if true, cells are replaced by colour-format dicts
            (see ``_grid_format``).
        :returns: the grid, a list of rows, after it has been passed through
            ``sswizard_utils.gridreduce(values, [[]])``.
        """
        if source_value == "":
            source_objs = self.of.query(source_type)
        else:
            source_objs = [self.of.object_get(source_type,source_value)]

        # column labels in factory order, duplicate names dropped
        ynames = []
        for _yaxis_obj in self.of.query(yaxis_type):
            if _yaxis_obj.name not in ynames:
                ynames.append(_yaxis_obj.name)

        # Order rows by enum value and keep each label on the same row as its
        # data. Previously, labels followed dict-key order while data was
        # placed by enum value, so the two could disagree.
        xaxis_enum = self.enums[xaxis_type]['name2enum']
        xnames = [xname for xname, _ in sorted(xaxis_enum.items(),
                                               key=lambda item: item[1])]

        values = [[''] + ynames]
        for xval in xnames:
            row = [xval]
            for yval in ynames:
                row.append(self._cell_get(source_objs,yval,xval,ztypes,constraints))
            values.append(row)

        sswizard_utils.gridreduce(values,[[]])

        if formatson == True:
            self._grid_format(values)

        # returned whether or not formatting was applied (it used to be
        # returned only when formatson was true)
        return values

    def _cell_get(self,source_objs,yval,xval,ztypes,constraints):
        """Return the contents of the cell for column ``yval`` and row ``xval``."""
        celltext = []
        for source_obj in source_objs:
            for lesson in source_obj.lessons.get(yval,{}).get(xval,[]):
                if constraints is not None and any(
                        getattr(lesson,objtype).name != objval
                        for objtype,objval in constraints):
                    continue

                if ztypes == ['*']:
                    # count mode: one running total, the first match counts as 1
                    if celltext:
                        celltext[0] += 1
                    else:
                        celltext.append(1)
                    continue

                names = []
                for ztype in ztypes:
                    if hasattr(lesson,ztype):
                        name = getattr(lesson,ztype).name
                        if name not in names:
                            names.append(name)

                entry = tuple(names)
                if entry not in celltext:
                    celltext.append(entry)
        return celltext

    def _formats_get(self,items):
        """Return a tuple of ``dict(value, bgcolor, fgcolor)``, one per item."""
        bgs,fgs = self._color_get_multi(items)
        return tuple(dict(value=value,bgcolor=bg,fgcolor=fg)
                     for value,bg,fg in zip(items,bgs,fgs))

    def _grid_format(self,values):
        """Replace the grid's cells in place with colour-format structures.

        A non-list cell in row 0 or column 0 becomes
        ``dict(value=..., bgcolor=black, fgcolor=white)``. A non-empty list
        cell becomes a list with one tuple of format dicts per entry. Empty
        list cells are left unchanged.
        """
        for x,row in enumerate(values):
            for y,cell in enumerate(row):
                if not isinstance(cell,list):
                    if x == 0 or y == 0:
                        row[y] = dict(value=cell,bgcolor=black,fgcolor=white)
                    continue

                if not cell:
                    continue

                formatted = []
                if len(cell) == 1:
                    item = cell[0]
                    if isinstance(item,tuple):
                        # 1 item, multi attributes
                        formatted.append(self._formats_get(item))
                    elif not isinstance(item,list):
                        # 1 item, single value (e.g. a count). Wrap it in a
                        # 1-tuple so every entry has the same shape. This
                        # value used to be dropped, and the debug print
                        # referenced undefined colour names.
                        formatted.append(self._formats_get((item,)))
                else:
                    # multiple items
                    for item in cell:
                        formatted.append(self._formats_get(item))
                row[y] = formatted

    def lesson_change(self,lesson):
        """Index ``lesson`` on its ``adult`` and ``student`` objects.

        Each of those objects gets a ``lessons`` dict, created if missing, of
        the form ``lessons[xid][yid] -> [lesson, ...]``. Here ``xid`` and
        ``yid`` are the ``objid`` of two of the lesson's attributes.
        ``viewer`` reads these indexes.
        """
        def _add(obj,xtype,ytype,lesson):
            xtype_id = getattr(lesson,xtype).objid
            ytype_id = getattr(lesson,ytype).objid
            obj.lessons.setdefault(xtype_id,{}).setdefault(ytype_id,[]).append(lesson)

        adult = lesson.adult
        student = lesson.student

        # add the lesson to the adult object
        if not hasattr(adult,'lessons'):
            setattr(adult,'lessons',{})

        _add(adult,'dow','period',lesson) # indexed by dow/period
        _add(adult,'student','period',lesson) # indexed by student/period

        # add the lesson to the student object
        if not hasattr(student,'lessons'):
            setattr(student,'lessons',{})

        _add(student,'dow','period',lesson) # indexed by dow/period
        _add(student,'adult','period',lesson) # indexed by adult/period
        _add(student,'period','recordtype',lesson) # indexed by period/recordtype
        _add(student,'student','recordtype',lesson) # indexed by student/recordtype

    @logger(log)
    def load(self,saveversion, dow=None, prep=None, period=None, teacher=None, student=None, source=None,
             unknown='N'):
        """Reset the object factory and reload it from the ``lesson`` table.

        A filter argument left as ``None`` is not applied.

        :param saveversion: accepted for compatibility; not used here.
        :param dow, period, teacher, student: exact-match string filters.
        :param prep: exact-match filter (inserted unquoted). Together with
            ``dow``, it is also passed to ``sswizard_utils.setenums``.
        :param source: comma-separated list of source values to include.
        :param unknown: ``'N'`` (default) excludes rows whose student,
            subject or teacher is ``"??"``.
        """
        def _quoted(text):
            # Wrap in double quotes for the where clause, doubling any
            # embedded quote so the value cannot end the literal early.
            # NOTE(review): parameterized queries would be safer, but that
            # depends on tbl_rows_get, which is not visible here.
            return '"' + text.replace('"','""') + '"'

        self.of.reset()

        whereclause = []

        # unknown
        if unknown=='N':
            whereclause.append(['student',"<>",_quoted("??")])
            whereclause.append(['subject',"<>",_quoted("??")])
            whereclause.append(['teacher',"<>",_quoted("??")])
        log.log(thisfuncname(),3,msg="loading",unknown=str(unknown))

        # prep
        if prep is None:
            prep = -1
        else:
            whereclause.append(['prep',"=",prep])
        log.log(thisfuncname(),3,msg="loading",prep=str(prep))

        # period (previously set prep = -1 here, which clobbered the prep
        # passed to setenums below)
        if period is None:
            period = "all"
        else:
            whereclause.append(['period',"=",_quoted(period)])
        log.log(thisfuncname(),3,msg="loading",period=str(period))

        # dow
        if dow is None:
            dow = "all"
        else:
            whereclause.append(['dow',"=",_quoted(dow)])
        log.log(thisfuncname(),3,msg="loading",dow=str(dow))

        # teacher
        if teacher is None:
            teacher = "all"
        else:
            whereclause.append(['teacher',"=",_quoted(teacher)])
        log.log(thisfuncname(),3,msg="loading",teacher=str(teacher))

        # student
        if student is None:
            student = "all"
        else:
            whereclause.append(['student',"=",_quoted(student)])
        log.log(thisfuncname(),3,msg="loading",student=str(student))

        # source
        if source is None:
            source = "dbinsert"
        else:
            _sources = [_quoted(_source) for _source in source.split(",")]
            whereclause.append(['source',"in","("+",".join(_sources)+")"])
        log.log(thisfuncname(),3,msg="loading",source=str(source))

        # get enums
        self.enums = sswizard_utils.setenums(dow,prep,self.refdatabase)

        # load from database
        cols = ['period','student','session','dow','teacher','subject','userobjid','status','substatus','recordtype','source']
        with self.database:
            colndefn,rows,exec_str = tbl_rows_get(self.database,'lesson',cols,whereclause)

            log.log(thisfuncname(),9,msg="dbread",exec_str=exec_str)

        # the same columns as attribute names: db column 'teacher' -> 'adult'
        attrcols = ['period','student','session','dow','adult','subject','userobjid','status','substatus','recordtype','source']

        # parse rows (the unused session.split(".") 4-way unpack was removed;
        # it raised on any session without exactly four parts)
        for row in rows:
            datamembers = dict(zip(attrcols,row))
            datamembers['objtype'] = 'lesson'

            lesson = self.of.new(schoolschedgeneric,'lesson',objid=datamembers['userobjid'],
                                 constructor='datamembers',database=self.database,
                                 of=self.of,modname=__name__,dm=datamembers)

            self.lesson_change(lesson)

            log.log(thisfuncname(),3,msg="loading row",dm=datamembers)

        # post log with results
        log.log(thisfuncname(),3,msg="db rows loaded",num=len(rows))
        for objtype in attrcols:
            try:
                num = len(self.of.store[objtype])
            except KeyError:
                # no objects of this type were created (e.g. zero rows)
                continue
            log.log(thisfuncname(),3,msg="lesson obj created",objtype=objtype,num=num)

    def updates_get(self,gridname,ignoreaxes=False):
        """Delegate to ``sswizard_utils.updates_get`` for this viewer."""
        return(sswizard_utils.updates_get(self,gridname,ignoreaxes))

    def _lastsaveversion_get(self):
        """Return ``max(saveversion)`` from ``lesson``, or -1 if the query fails."""
        try:
            with self.database:
                colndefn,rows = tbl_query(self.database,"select max(saveversion) from lesson")
            return(rows[0][0])
        except Exception:
            return(-1)
class Test_ObjFrameworkDumpRptNestedSchoolsched(unittest.TestCase):
    """``dumpobjrpt`` tests using a ``schoolschedgeneric`` object.

    Every attribute of the lesson is itself a nested object (not a plain
    string or int) that may need to be accessed via accessors.
    """

    def setUp(self):
        self.of = ObjFactory(True)
        self.database = Database('foobar')

        datamembers = dict(period='830',
                           student='Booker',
                           dow='MO',
                           teacher='Amelia',
                           saveversion=0,
                           session='AM.AC.SC')

        self.foobar= self.of.new(schoolschedgeneric,
                                 'DBLesson',
                                 objid='dblesson0',
                                 constructor='datamembers',
                                 database=self.database,
                                 of=self.of,
                                 modname=__name__,
                                 dm=datamembers)

    def test_(self):
        results = self.of.dumpobjrpt(objref=False)

        expected_results = [['ROOT', 'period'],
                            ['ROOT', 'saveversion'],
                            ['ROOT', 'dow'],
                            ['dblesson0', 'dow'],
                            ['dblesson0', 'period'],
                            ['dblesson0', 'saveversion'],
                            ['dblesson0', 'session'],
                            ['dblesson0', 'student'],
                            ['dblesson0', 'teacher'],
                            ['ROOT', 'DBLesson'],
                            ['ROOT', 'session'],
                            ['ROOT', 'student'],
                            ['ROOT', 'teacher']]

        # report order follows object-store iteration order; sort both sides
        # as the sibling tests do so the comparison is deterministic
        expected_results.sort()
        results.sort()

        self.assertListEqual(expected_results,results)

    def test_filter_objtype(self):
        results = self.of.dumpobjrpt(objtypes=['DBLesson','student'],objref=False)

        expected_results = [['dblesson0', 'student'],
                            ['ROOT', 'DBLesson'],
                            ['ROOT', 'student']]

        expected_results.sort()
        results.sort()

        self.assertListEqual(expected_results,results)

    def test_filter_objtype_3items(self):
        results = self.of.dumpobjrpt(objtypes=['DBLesson','student','dow'],objref=False)

        expected_results = [['dblesson0', 'student'],
                            ['ROOT', 'DBLesson'],
                            ['ROOT', 'student'],
                            ['dblesson0', 'dow'],
                            ['ROOT', 'dow']]

        expected_results.sort()
        results.sort()

        self.assertListEqual(expected_results,results)

    def test_filter_objtype_field_filters(self):
        expected_results = [['ROOT', '-', '-', 'student'],
                            ['ROOT', 'Amelia', 'Booker', 'DBLesson'],
                            ['dblesson0', '-', '-', 'student']]

        results = self.of.dumpobjrpt(objtypes=['DBLesson','student'],
                                     objref=False,
                                     fields=['teacher','student'])

        expected_results.sort()
        results.sort()

        self.assertListEqual(expected_results,results)

    def test_all_fields(self):
        # NOTE(review): only the outer list is sorted; the field order inside
        # the single row is compared exactly as dumpobjrpt emits it
        expected_results = [['ROOT', 'dm:teacher=Amelia',
                             'dm:session=AM.AC.SC', 'dm:student=Booker',
                             'dm:period=830', 'dm:saveversion=0',
                             'dm:dow=MO', 'dm:dm:dow=MO', 'dow:MO',
                             'objid:dblesson0', 'period:830',
                             'recursion:True', 'saveversion:0',
                             'session:AM.AC.SC', 'student:Booker',
                             'teacher:Amelia']]

        results = self.of.dumpobjrpt(objtypes=['DBLesson'],
                                     objref=False,
                                     fields=['all'],
                                     omitfields=['id'],
                                     fieldnames=True)

        expected_results.sort()
        results.sort()

        self.assertListEqual(expected_results,results)

    def tearDown(self):
        self.of.reset()