def main():

    num_files = 8924  #8924

    max_coef_exp = 1000
    max_coef_slope = 5  # +- that %pwc in one day!!

    tot_num_transitions = 00  # across all users
    list_diff_states = []
    list_all_transitions = []

    file_name_lin = "temporal_series/most_weigh_ins/all_cuts/scatter_plot_parameters_lin.dat"
    file_lin = open(file_name_lin, 'wt')

    file_name_exp = "temporal_series/most_weigh_ins/all_cuts/scatter_plot_parameters_exp.dat"
    file_exp = open(file_name_exp, 'wt')

    for index_file in range(num_files):
        index_file += 1

        try:

            file_name = "temporal_series/most_weigh_ins/all_cuts/weigh_in_time_serie_days" + str(
                index_file) + "_filters.cuts"

            #file_name="temporal_series/most_weigh_ins/test_file_to_create_transition_matrix_"+str(index_file)
            file = open(file_name, 'r')
            list_lines_file = file.readlines()

            dir = file_name.split("weigh_in_")[0]

            ################################
            # lin: intersection + X*slope  #
            #exponential: A+B*exp(lambda*t)#
            ################################

            cont_lines = 1

            for line in list_lines_file:

                list_one_line = line.split(" ")

                trend = list_one_line[0]  # i characterize the current trend:

                if trend == "exp":
                    coef1_exp = round(float(list_one_line[1]),
                                      2)  # we dont care about the interception
                    coef2_exp = round(float(list_one_line[2]), 2)
                    coef3_exp = round(float(list_one_line[3]), 2)

                    if coef1_exp <= max_coef_exp and coef1_exp >= -max_coef_exp:
                        if coef2_exp <= max_coef_exp and coef2_exp >= -max_coef_exp:
                            if coef3_exp <= max_coef_exp and coef3_exp >= -max_coef_exp:

                                if coef2_exp < 0:  # i only care about linear increase/decrease
                                    new_coef2_exp = -1
                                else:
                                    new_coef2_exp = 1

                                if coef3_exp < 0:
                                    new_coef3_exp = -1
                                else:
                                    new_coef3_exp = 1

                                length = float(list_one_line[4])
                                state = trend + "_" + str(
                                    new_coef2_exp) + "_" + str(
                                        new_coef3_exp)  #+"_"+str(coef3_exp)
                                if state not in list_diff_states:
                                    list_diff_states.append(state)
                                    #print "exp",coef1_exp,coef2_exp,coef3_exp,length, state

                                print >> file_exp, coef1_exp, coef2_exp, coef3_exp
                    else:
                        print "bad exp_coef in file", index_file, coef1_exp, coef2_exp, coef3_exp

                elif trend == "lin":
                    coef1_lin = round(float(list_one_line[1]),
                                      2)  # we dont care about the interception
                    coef2_lin = round(float(list_one_line[2]), 2)
                    length = float(list_one_line[3])

                    if coef2_lin <= max_coef_slope and coef2_lin >= -max_coef_slope:

                        if coef2_lin < 0:  # i only care about linear increase/decrease
                            new_coef2_lin = -1
                        else:
                            new_coef2_lin = 1

                        state = trend + "_" + str(
                            new_coef2_lin)  #+"_"+str(coef2_lin)
                        if state not in list_diff_states:
                            list_diff_states.append(state)

                        #print "lin",coef2_lin,length
                        print >> file_lin, coef1_lin, coef2_lin

                    else:
                        print "bad lin_coef in file", index_file, coef1_lin, coef2_lin

                if len(
                        list_lines_file
                ) == 1:  # if only one line in the file --> only one state --> diagonal term for the matrix
                    list = []
                    list.append(state)
                    list.append(state)
                    list_all_transitions.append(list)
                    tot_num_transitions += 1.0

                else:  # if more than one trend in the file
                    if cont_lines == 1:
                        old_state = state
                    else:
                        new_state = state
                        list = []
                        list.append(old_state)
                        list.append(new_state)
                        list_all_transitions.append(list)

                        old_state = state  # i update for the next transition
                        tot_num_transitions += 1.0

                cont_lines += 1
        except IOError:
            print index_file, "file doesnt exist"

    list_diff_states = sorted(
        list_diff_states
    )  # this is the order of the trends for the transition matrix

    print "#num of diff states:", len(
        list_diff_states)  #, "namely:",list_diff_states

    file_lin.close()  #for the scatter plot of the parameters
    file_exp.close()

    ########################i create the empty transition matrix
    matrix = []
    for i in range(len(list_diff_states)):
        matrix.append(
            [0.000] * len(list_diff_states)
        )  #ojo!!!!!!!!!!!!!!! si la creo: matrix.append(list), con list=[0.]*len(list_diff_states), al modificar unos elementos, modificare otros sin querer!!!!!!!!!!!!!!!!!!!

##################################

    print "num transitions:", len(
        list_all_transitions
    )  #,"  list transitions:",list_all_transitions ,"\n"  #ej.: [['exp_1.0_1.0_2.0', 'lin_3.0_1.0'], ['lin_3.0_1.0', 'lin_2.0_1.0'], ['exp_2.0_1.0_3.0', 'exp_2.0_2.0_1.0'], ...]

    cont_exp_to_lin = 0
    cont_lin_to_exp = 0
    for transition in list_all_transitions:
        # print "\n\n looking for:", transition[0],transition[1]
        for i in range(len(list_diff_states)):
            old_state = list_diff_states[i]
            for j in range(len(list_diff_states)):

                new_state = list_diff_states[j]

                if transition[0] == old_state and transition[1] == new_state:
                    matrix[i][j] += 1.000
                    #print "  ",transition, matrix[i][j]  #, list_diff_states[i], list_diff_states[j]
                    if old_state == 'exp_1_-1' and new_state == 'lin_1':
                        #print  old_state, new_state
                        cont_exp_to_lin += 1
                    elif old_state == 'lin_1' and new_state == 'exp_1_-1':
                        #print  old_state, new_state
                        cont_lin_to_exp += 1

    print "old_state == 'exp_1_-1' and new_state== 'lin_1':", cont_exp_to_lin, "old_state == 'lin_1' and new_state== 'exp_1_-1':", cont_lin_to_exp
    file_name1 = "temporal_series/most_weigh_ins/cuts/transiton_probability_matrix_test.dat"
    file1 = open(file_name1, 'wt')

    #print "\n\ntransition matrix:\n"
    #print "\t\t\t",

    # for item in list_diff_states:   # print the header
    #    print item," ",

    #print  "\n"

    for i in range(len(list_diff_states)):
        if "exp" in list_diff_states[i]:
            pass
            #print list_diff_states[i],"\t",
        else:
            pass  #print list_diff_states[i],"\t\t",
        for j in range(len(list_diff_states)):
            # print round(matrix[i][j]/tot_num_transitions,3),"             ",
            print >> file1, round(matrix[i][j] / tot_num_transitions, 3),

    # print "\n"
        print >> file1, "\n"

    file1.close()

    #### PLOTGRACE IMPLEMENTATION ###
    # Draw the matrix

    from PlotGrace import plot_matrix  # i use a function from the module
    # and just modify some of the default options

    plot_matrix(
        matrix,
        dir + "transition_prob_matrix.agr",
        xticklabels=list_diff_states,
        yticklabels=list_diff_states,
        xticklabelangle=90,
        colorscheme='YlOrRd',
        #logcolorbar=True,
        reversecolorscheme=False,
        mincolor=(255, 255, 255),
        title="Transition Probabilities")
Ejemplo n.º 2
0
#!/usr/bin/env python

import sys
from PlotGrace import plot_matrix

# Usage: script.py [outfile.agr [colorscheme]] < matrix.txt
# Reads a whitespace-separated numeric matrix (one row per line) from stdin
# and draws it with PlotGrace.
outfilename = "matrix.agr"
colorscheme = "YlOrRd"
if len(sys.argv) >= 2:
    outfilename = sys.argv[1]
if len(sys.argv) >= 3:
    colorscheme = sys.argv[2]

# INPUT MATRIX
mx = []
for rowline in sys.stdin:
    # Skip blank lines: matrix files written with `print >> f, "\n"` put an
    # empty line after every row, which would otherwise become an empty row.
    if not rowline.strip():
        continue
    mx.append([float(value) for value in rowline.split()])

# format check: the matrix must be square (warn once, not once per row)
num_rows = len(mx)
if any(len(row) != num_rows for row in mx):
    print >> sys.stderr, "Warning! n_rows != n_cols"

# inform size
print >> sys.stderr, "<plot_matrix> Network size: ", num_rows

plot_matrix(mx, outfilename, colorscheme=colorscheme, reversecolorscheme=True)
def main():

    max_coef_exp = 1000
    max_coef_slope = 5  # +- that %pwc in one day!!

    tot_num_transitions = 0  # across all users
    list_diff_states = []  # for the transition matrix
    list_all_transitions = []  # for the transition matrix

    database = "calorie_king_social_networking_2010"
    server = "tarraco.chem-eng.northwestern.edu"
    user = "******"
    passwd = "n1ckuDB!"
    db = Connection(server, database, user, passwd)

    query = """select * from users"""
    result1 = db.query(query)  # is a list of dictionaries

    contador_time_series = 0
    contador = 0

    for r1 in result1:  #loop over users
        contador += 1

        ck_id = r1['ck_id']

        list_start_end_dates_one_user = []  #without considering the gaps

        list_states_one_user = [
        ]  #without considering the gaps (just diff. behaviors)

        dicc_states_one_user = {
        }  # the key is the initial_day, the value is the type of trend  or gap
        temp_dicc_states = {}

        try:  #only for the users with an entry in that time_series related table: one of the (filtered) time series

            ######################################
            # for the info about different trends# (for a given, fixed user)
            ######################################

            contador_time_series += 1
            print "\n\n\n", contador, ck_id

            query2 = "select  * from weigh_in_cuts where (ck_id ='" + str(
                ck_id) + "')  order by start_day asc"
            result2 = db.query(query2)

            #result2: [{'fit_type': u'exp', 'stop_idx': 4L, 'ck_id': u'a48eba6e-51ad-42bc-b367-a24bb6504f0a', 'id': 68008L, 'cost': 0.5, 'param3': -0.60660499999999995, 'param2': 3.4280400000000002, 'param1': -3.1829299999999998, 'start_idx': 0L}, {'fit_type': u'lin', 'stop_idx': 16L, 'ck_id': u'a48eba6e-51ad-42bc-b367-a24bb6504f0a', 'id': 68007L, 'cost': 0.5, 'param3': None, 'param2': 0.015299699999999999, 'param1': -4.1967499999999998, 'start_idx': 5L}]

            num_trends = len(result2)

            for r2 in result2:

                start_stop_days = []  # for that particular segment

                starting_day = int(r2['start_day'])
                ending_day = int(r2['stop_day'])
                trend = r2['fit_type']
                param1 = r2['param1']
                param2 = r2['param2']
                param3 = r2['param3']

                start_stop_days.append(starting_day)
                start_stop_days.append(ending_day)

                list_start_end_dates_one_user.append(start_stop_days)

                if trend == "exp":
                    #i dont really care about the interception
                    coef1_exp = param1
                    coef2_exp = param2
                    coef3_exp = param3  ################################
                    # lin: intersection + X*slope  #
                    #exponential: A+B*exp(lambda*t)#
                    ################################

                    if coef1_exp <= max_coef_exp and coef1_exp >= -max_coef_exp:  # to avoid weird fits
                        if coef2_exp <= max_coef_exp and coef2_exp >= -max_coef_exp:
                            if coef3_exp <= max_coef_exp and coef3_exp >= -max_coef_exp:

                                if coef2_exp < 0 and coef3_exp > 0:  # i only care about exp increase/decrease
                                    state = "exp_down"
                                elif coef2_exp > 0 and coef3_exp < 0:
                                    state = "exp_down"
                                elif coef2_exp > 0 and coef3_exp > 0:
                                    state = "exp_up"
                                elif coef2_exp < 0 and coef3_exp < 0:
                                    state = "exp_up"

                            # elif coef2_exp==0.0 or coef2_exp== -0.0:
                            #    state="flat"
                                else:
                                    print ck_id, coef2_exp, coef3_exp, "not one of the two exp types"
                                    raw_input()

                    else:
                        print "values for the fit coef. too weird!", param1, param2, param3, ck_id

                elif trend == "lin":
                    # i dont really care about the interception
                    coef2_lin = param2

                    if coef2_lin <= max_coef_slope and coef2_lin >= -max_coef_slope:  # to avoid weird fits

                        if coef2_lin < 0:  # i only care about linear increase/decrease
                            state = "lin_down"

                            if coef2_lin > -0.0001:
                                state = "flat"

                        elif coef2_lin == 0.0:
                            state = "flat"
                        else:
                            state = "lin_up"
                            if coef2_lin < 0.0001:
                                state = "flat"

                list_states_one_user.append(state)
                temp_dicc_states[
                    starting_day] = state  # i save the pair starting_day, trend for that user
                if state not in list_diff_states:
                    list_diff_states.append(state)

            ####################### end loop over result2  (info diff trends)

        except MySQLdb.Error:
            pass  #for the users without an entry in that time_series related table: not one of the (filtered) time series

        ##########################
        # for the gap info       # (for the same given, fixed user  ck_id)
        ##########################

        query3 = "select  * from gaps_by_frequency where (ck_id ='" + str(
            ck_id) + "')  order by start_day asc"
        result3 = db.query(query3)

        #result3: [{'file_index': 1408, 'ck_id': 8647c765-e37e-4024-92da-be838b792379, 'start_date':2009-05-07 00:00:00, 'end_date':2009-06-08 00:00:00, 'start_day': 108, 'end_day': 140, 'days_gap': 32, 'zscore_gap': 3.83125},{'file_index': 1408, 'ck_id': 8647c765-e37e-4024-92da-be838b792379, 'start_date':2009-08-04 00:00:00, 'end_date':2009-09-30 00:00:00, 'start_day': 197, 'end_day': 254, 'days_gap': 57, 'zscore_gap': 7.1496} ]

        num_gaps = len(result3)

        if num_gaps > 0:

            for r3 in result3:

                file_index = r3['file_index']

                starting_gap = int(r3['start_day'])
                ending_gap = int(r3['end_day'])
                trend = "gap"
                zscore_gap = r3[
                    'zscore_gap']  # threshold to consider a gap statistically sifnificant  zs>=3  (imposed like that in: analyze_frequency_gaps_in_time_series_frequencies_EDIT_DB.py)

                if trend not in list_diff_states:
                    list_diff_states.append(trend)

                cont = -1  #to go over list_states_one_user
                for segment in list_start_end_dates_one_user:  #  the list is sorted chronologically
                    cont += 1

                    start_behavior = segment[0]
                    end_behavior = segment[1]

                    if (starting_gap <= end_behavior) and (
                            ending_gap >= start_behavior):  # if there is a gap
                        # in the middle of a behavior
                        num_trends += 2
                        ## i cut the trend in two segments, with a gap in between
                        new_segment1 = [start_behavior, starting_gap]
                        new_segment2 = [starting_gap, ending_gap]
                        new_segment3 = [ending_gap, end_behavior]

                        temp_dicc_states[
                            start_behavior] = list_states_one_user[cont]
                        temp_dicc_states[starting_gap] = trend
                        temp_dicc_states[ending_gap] = list_states_one_user[
                            cont]

                        #print "gap cutting in the middle of a single behavior:",new_segment1,new_segment2,new_segment3, ck_id

                for i in range(
                        len(list_start_end_dates_one_user)
                ):  # i check whether there is a gap in between DIFF behaviors
                    try:

                        old_ending = list_start_end_dates_one_user[i][1]
                        new_beginnig = list_start_end_dates_one_user[i + 1][0]

                        if (starting_gap >= old_ending) and (ending_gap <=
                                                             new_beginnig):

                            # print "gap in the middle of two diff. behaviors:",list_states_one_user[i],trend,list_states_one_user[i+1]

                            temp_dicc_states[starting_gap] = trend
                            temp_dicc_states[
                                ending_gap] = list_states_one_user[i + 1]

                    except IndexError:
                        pass
                    #print "no room for any more gaps,",len(list_start_end_dates_one_user),i

                # create a final list of all states, and then go for state in list_states: and copy code from aux, line 70 on

            # end loop over result3   (gap info)

        else:  #if the gap info doesnt change the number of trends (NO gaps)
            pass

        for key in temp_dicc_states.keys(
        ):  # i make a copy of the dicc, to be the final one
            dicc_states_one_user[key] = temp_dicc_states[key]

# i account for all possible states and transsitions between states:

        if len(dicc_states_one_user) > 1:  # several trends for the time series

            cont = 0
            for key in sorted(dicc_states_one_user.keys(
            )):  # i print out the result of combining weigh_in cuts and gaps:
                print key, dicc_states_one_user[key]
                state = dicc_states_one_user[key]
                if state not in list_diff_states:
                    list_diff_states.append(state)

                list = []
                if cont == 0:  # for the first state
                    state1 = dicc_states_one_user[key]
                else:  # for all the rest of states in the sorted dictionary
                    state2 = dicc_states_one_user[key]
                    list.append(state1)
                    list.append(state2)
                    list_all_transitions.append(list)

                    state1 = state2
                cont += 1

        elif len(dicc_states_one_user
                 ) == 1:  # one single trend for the time series
            list = []
            for key in dicc_states_one_user:
                list.append(dicc_states_one_user[key])
                list.append(dicc_states_one_user[key])
                print key, dicc_states_one_user[key]
                if state not in list_diff_states:
                    list_diff_states.append(state)

            list_all_transitions.append(list)

        #print  list_all_transitions

#       if num_gaps>0:
#          raw_input()

################################# end loop over result1  (loop over users)

#  i   create the empty transition matrix

    list_diff_states = sorted(
        list_diff_states
    )  # this is the order of the trends for the transition matrix

    matrix = []
    for i in range(len(list_diff_states)):
        matrix.append(
            [0.000] * len(list_diff_states)
        )  #ojo!!!!!!!!!!!!!!! si la creo: matrix.append(list), con list=[0.]*len(list_diff_states), al modificar unos elementos, modificare otros sin querer!!!!!!!!!!!!!!!!!!!

    print "num transitions:", len(
        list_all_transitions), "diff. states:", list_diff_states

    for transition in list_all_transitions:
        for i in range(len(list_diff_states)):
            old_state = list_diff_states[i]
            for j in range(len(list_diff_states)):

                new_state = list_diff_states[j]

                if transition[0] == old_state and transition[1] == new_state:
                    matrix[i][j] += 1.000

    file_name1 = "temporal_series/transiton_probability_matrix_test_with_gap_info.dat"
    file1 = open(file_name1, 'wt')

    for i in range(len(list_diff_states)):
        for j in range(len(list_diff_states)):
            print >> file1, round(
                matrix[i][j] / float(len(list_all_transitions)), 3),

        print >> file1, "\n"
    file1.close()

    #### PLOTGRACE IMPLEMENTATION ###
    # Draw the matrix

    from PlotGrace import plot_matrix  # i use a function from the module
    # and just modify some of the default options

    plot_matrix(
        matrix,
        "temporal_series/transition_prob_matrix_with_gaps.agr",
        xticklabels=list_diff_states,
        yticklabels=list_diff_states,
        xticklabelangle=90,
        colorscheme='YlOrRd',
        #logcolorbar=True,
        reversecolorscheme=False,
        mincolor=(255, 255, 255),
        title="Transition Probabilities")
Ejemplo n.º 4
0
def main():

    min_wi = 20  #Filter1:  min number of weigh ins >=
    #  min_timespan=0       # Filter2: min length of the serie

    max_num_users_for_testing = 2000  # this is just while i test the code!!

    database = "calorie_king_social_networking_2010"
    server = "tarraco.chem-eng.northwestern.edu"
    user = "******"
    passwd = "n1ckuDB!"

    db = Connection(server, database, user, passwd)

    query = """select * from users"""
    result1 = db.query(query)  # is a list of dict.

    num_users = len(result1)

    dir = "../Results/"

    list_diff_states = []
    list_all_transitions = []
    tot_num_transitions = 0.
    num_useful_users = 0

    cont_users = 0
    for r1 in result1:  #loop over users to get their number_of_weigh-ins
        ## r1 is a dict.:  {'ck_id': u'bd84dbe2-dd6e-4125-b006-239442df2ff6', 'age': 52L, 'state': u'', 'height': 64L, 'join_date': datetime.datetime(2009, 11, 27, 10, 41, 5), 'is_staff': u'public', 'most_recent_weight': 142.0, 'initial_weight': 144.0, 'id': 1L}

        #   if cont_users <=max_num_users_for_testing:   #COMMENT THIS LINE FOR THE FINAL RUN!!!

        ck_id = r1['ck_id']
        id = r1['id']

        print cont_users, ck_id
        cont_users += 1

        query2 = "select  * from weigh_in_history where (ck_id ='" + str(
            ck_id) + "')  order by on_day asc"
        result2 = db.query(query2)

        r1['num_wi'] = len(
            result2
        )  # i add another key-value to the dict. -->> with this i ALSO modify the list of dict. result!!!

        first = result2[0]['on_day']
        last = result2[-1]['on_day']
        time_span_days = (last - first).days + 1

        if r1['num_wi'] >= min_wi:  #and   int(time_span_days) >= min_timespan:

            dict_start_day_type_segment = {}

            query3 = "select * from weigh_in_cuts  where (ck_id ='" + str(
                ck_id) + "') order by start_day"
            result3 = db.query(query3)  # is a list of dict.

            for r3 in result3:  # each line is a dict, each line is a segment

                fit_type = str(r3['fit_type'])
                state = fit_type

                start_day = int(r3['start_day'])
                stop_day = int(r3['stop_day'])
                start_weight = float(r3['start_weight'])
                stop_weight = float(r3['stop_weight'])

                if fit_type != "isolated":  # i get the dict states_startig days

                    param1 = float(r3['param1'])
                    try:
                        param2 = float(
                            r3['param2']
                        )  # cos  constant dont have a value for this
                    except:
                        pass

                    try:  # cos lin and constant dont have a value for this
                        param3 = float(r3['param3'])
                    except:
                        pass

                    if fit_type == "exponent":
                        if param3 < 0.:
                            if param2 > 0.:
                                state = fit_type + "_desc"
                                dict_start_day_type_segment[start_day] = state
                            else:
                                state = fit_type + "_asc"
                                dict_start_day_type_segment[start_day] = state
                        else:
                            print "we shouldnt have positive taus!", ck_id

                    elif fit_type == "linear":
                        if param2 < 0.:
                            state = fit_type + "_desc"
                            if param2 > -0.001:  # if very small slope, it is actually constant (this around 5% per two month)
                                state = "constant"

                        else:
                            state = fit_type + "_asc"
                            if param2 < 0.001:
                                state = "constant"

                        dict_start_day_type_segment[start_day] = state

                    else:  # constant

                        dict_start_day_type_segment[start_day] = state

            if len(dict_start_day_type_segment
                   ) > 0:  # user with at least one useful segment
                num_useful_users += 1

                query4 = "SELECT * FROM frequency_gaps where (ck_id ='" + str(
                    ck_id) + "') order by start_day"  #gap info
                result4 = db.query(query4)  # is a list of dict.

                for r4 in result4:

                    start_day = int(r4['start_day'])
                    stop_day = int(r4['stop_day'])
                    start_weight = float(r4['start_weight'])
                    stop_weight = float(r4['stop_weight'])

                    dict_start_day_type_segment[start_day] = "gap"

                list_tuples_sorted_dict = sorted(
                    dict_start_day_type_segment.iteritems(),
                    key=operator.itemgetter(0)
                )  # the index of itermgetter is by which i order the list of tuples that was the dictionary
                print ck_id,

                ######### once i get all gaps, i merge consecutive gaps into just one:
                print " states before merging gaps:", list_tuples_sorted_dict, len(
                    list_tuples_sorted_dict)

                flag_first_gap_index = 0
                list_items_to_remove = []
                for item in list_tuples_sorted_dict:
                    # print item
                    if item[1] == 'gap':
                        if flag_first_gap_index == 0:
                            flag_first_gap_index = 1
                        else:
                            list_items_to_remove.append(item)
                            flag_first_gap_index = 1

                    else:
                        flag_first_gap_index = 0  #i dont want to remove the first gap right after a segment

                if len(
                        list_items_to_remove
                ) > 0:  # IT DIDT WORK TO REMOVE THEM AS I GO OVER THE LIST IN THE PREVIOS LOOP
                    #  print  "list items to remove:",list_items_to_remove

                    for item in list_items_to_remove:
                        list_tuples_sorted_dict.remove(item)

                # print "states after merging gaps:", list_tuples_sorted_dict,"\n"
                # raw_input()

                if len(
                        list_tuples_sorted_dict
                ) == 1:  # if only one state --> diagonal term for the matrix

                    state = list_tuples_sorted_dict[0][1]

                    if state not in list_diff_states:  # for the transition matrix
                        list_diff_states.append(state)

                    lista = []
                    lista.append(state)
                    lista.append(state)
                    list_all_transitions.append(lista)
                    tot_num_transitions += 1.0

                else:  # if more than one trend

                    cont = 0
                    for item in list_tuples_sorted_dict:
                        state = item[1]
                        if state not in list_diff_states:  # for the transition matrix
                            list_diff_states.append(state)

                        if cont > 0:
                            new_state = state
                            lista = []
                            lista.append(old_state)
                            lista.append(new_state)
                            list_all_transitions.append(lista)
                            tot_num_transitions += 1.0

                        old_state = state  # i update for the next transition
                        cont += 1

    ################# end loop over users

    list_diff_states = sorted(
        list_diff_states
    )  # this is the order of the trends for the transition matrix

    print "#num of diff states:", len(
        list_diff_states), "namely:", list_diff_states

    print "#num useful users:", num_useful_users

    print "num transitions:", len(
        list_all_transitions
    )  #,"  list transitions:",list_all_transitions ,"\n"  #ej.: [['exp_1.0_1.0_2.0', 'lin_3.0_1.0'], ['lin_3.0_1.0', 'lin_2.0_1.0'], ['exp_2.0_1.0_3.0', 'exp_2.0_2.0_1.0'], ...]

    matrix = []  # raw count of transitions
    norm_matrix = [
    ]  # normalized count of transitios by the tot number of them
    for i in range(len(list_diff_states)):
        matrix.append(
            [0.000] * len(list_diff_states)
        )  #ojo!!!!!!!!!!!!!!! si la creo: matrix.append(list), con list=[0.]*len(list_diff_states), al modificar unos elementos, modificare otros sin querer!!!!!!!!!!!!!!!!!!!
        norm_matrix.append([0.000] * len(list_diff_states))

    for transition in list_all_transitions:
        # print "\n\n looking for:", transition[0],transition[1]
        for i in range(len(list_diff_states)):
            old_state = list_diff_states[i]
            for j in range(len(list_diff_states)):

                new_state = list_diff_states[j]
                #print  old_state, new_state

                if transition[0] == old_state and transition[1] == new_state:
                    matrix[i][j] += 1.000
                    norm_matrix[i][j] += 1.000

    for i in range(len(list_diff_states)):
        for j in range(len(list_diff_states)):
            norm_matrix[i][j] = norm_matrix[i][j] / tot_num_transitions

    file_name1 = dir + "numerical_values_transiton_probability_matrix.dat"
    file1 = open(file_name1, 'wt')

    for i in range(len(list_diff_states)):
        if "exp" in list_diff_states[i]:
            pass
            #print list_diff_states[i],"\t",
        else:
            pass  #print list_diff_states[i],"\t\t",
        for j in range(len(list_diff_states)):
            # print round(matrix[i][j]/tot_num_transitions,3),"             ",
            print >> file1, round(
                matrix[j][i] / tot_num_transitions, 3
            ),  # so the order row/columns matches with the .agr representation

    # print "\n"
        print >> file1, "\n"
    file1.close()

    print "printed out matrix textfile:", dir + "transiton_probability_matrix_test.dat"

    #### PLOTGRACE IMPLEMENTATION ###
    # Draw the matrix

    plot_matrix(
        matrix,
        dir + "transition_prob_matrix.agr",
        xticklabels=list_diff_states,
        yticklabels=list_diff_states,
        xticklabelangle=90,
        colorscheme='YlOrRd',
        #logcolorbar=True,
        xlabel="Initial State",
        ylabel="Final State",
        reversecolorscheme=False,
        mincolor=(255, 255, 255),
        title="Transition Probabilities"
    )  #########Axis for the matrix plot are X: initial state,   Y: final state.

    print "printed out matrix plot:", dir + "transition_prob_matrix.agr"

    plot_matrix(
        norm_matrix,
        dir + "transition_prob_matrix_norm.agr",
        xticklabels=list_diff_states,
        yticklabels=list_diff_states,
        xticklabelangle=90,
        colorscheme='YlOrRd',
        #logcolorbar=True,
        xlabel="Initial State",
        ylabel="Final State",
        reversecolorscheme=False,
        mincolor=(255, 255, 255),
        title="Transition Probabilities"
    )  #########Axis for the matrix plot are X: initial state,   Y: final state.

    print "printed out matrix plot:", dir + "transition_prob_matrix_norm.agr"