def parse_sorting_policy(buf):
    """Turn a textual policy dump into a ``MapAgent``.

    ``buf`` is split on newlines. Every line of the form
    ``<state> = <action>`` contributes one policy entry:

    * the first and last characters of ``<state>`` are dropped, the rest is
      split on commas and its first four fields are converted with ``int``
      and handed to ``sortingState``;
    * ``<action>`` must be the name of one of the action classes listed in
      the table below.

    Lines that do not contain the `` = `` separator are skipped. An unknown
    action name prints a message and ends the process through ``exit(0)``.

    Returns ``MapAgent(policy)``, where ``policy`` maps each parsed state to
    a freshly constructed action instance.
    """
    # Each entry is wrapped in a lambda so that an action class name is only
    # looked up when that action really occurs in the input.
    action_builders = {
        "InspectAfterPicking": lambda: InspectAfterPicking(),
        "InspectWithoutPicking": lambda: InspectWithoutPicking(),
        "Pick": lambda: Pick(),
        "PlaceOnConveyor": lambda: PlaceOnConveyor(),
        "PlaceInBin": lambda: PlaceInBin(),
        "ClaimNewOnion": lambda: ClaimNewOnion(),
        "ClaimNextInList": lambda: ClaimNextInList(),
        "PlaceInBinClaimNextInList": lambda: PlaceInBinClaimNextInList(),
    }

    policy = {}
    for line in buf.split("\n"):
        fields = line.split(" = ")
        if len(fields) < 2:
            continue
        raw_state, action_name = fields[0], fields[1]

        # Strip the enclosing delimiter characters, then read four integers.
        coords = raw_state[1:-1].split(",")
        state_key = sortingState(int(coords[0]), int(coords[1]),
                                 int(coords[2]), int(coords[3]))

        builder = action_builders.get(action_name)
        if builder is None:
            print("Invalid input policy to parse_sorting_policy")
            exit(0)

        policy[state_key] = builder()

    from mdp.agent import MapAgent
    return MapAgent(policy)
def saveDataForBaseline():
    """Write the sorting MDP out as text files for the BIRL baseline code.

    Works entirely on module-level state and has these effects, in order:

    * appends every state of the global ``model`` plus
      ``sortingState(-1, -1, -1, -1)`` to the global ``dummy_states`` and
      numbers the resulting list from 1 in ``dict_stateEnum``;
    * numbers seven action instances from 1 in ``dict_actEnum``;
    * passes the global ``traj`` to ``enumerateForBIRLsortingModel1`` and
      closes the global files ``f_st_BIRLcode`` and ``f_ac_BIRLcode``;
    * rewrites ``transition_matrix.txt``, ``features_matrix.txt`` and
      ``weights_experts.log`` under ``get_home() + "/BIRL_MLIRL_data/"``.

    Every output row is written as values each followed by a comma, then a
    newline. In the transition and feature files a blank line follows the
    rows belonging to one action.

    The output files are opened with ``with`` so that they are closed even
    when an exception is raised while writing.
    """
    sortingMDP = model
    for s in sortingMDP.S():
        dummy_states.append(s)
    dummy_states.append(sortingState(-1, -1, -1, -1))

    # Enumerations are 1-based.
    for ind, s in enumerate(dummy_states, 1):
        dict_stateEnum[ind] = s
    print("dict_stateEnum \n", dict_stateEnum)

    acts = [InspectAfterPicking(), PlaceOnConveyor(), PlaceInBin(),
            Pick(), ClaimNewOnion(), InspectWithoutPicking(),
            ClaimNextInList()]
    for ind, a in enumerate(acts, 1):
        dict_actEnum[ind] = a

    # record first trajectory in data for single task BIRL
    enumerateForBIRLsortingModel1(traj)

    f_st_BIRLcode.close()
    f_ac_BIRLcode.close()

    out_dir = get_home() + "/BIRL_MLIRL_data/"

    # --- transition matrices, one square block per action -----------------
    tuple_res = sortingMDP.generate_matrix(dict_stateEnum, dict_actEnum)
    dict_tr = tuple_res[0]
    n_states = len(dict_stateEnum)
    n_actions = len(dict_actEnum)
    # Opening with "w" truncates the file, so no separate truncate-then-append
    # step is needed.
    with open(out_dir + "transition_matrix.txt", "w") as tm_file:
        for act_id in range(1, n_actions + 1):
            acArray2d = np.empty((n_states, n_states))
            for row in range(1, n_states + 1):
                for col in range(1, n_states + 1):
                    acArray2d[row - 1][col - 1] = dict_tr[act_id][row][col]

            for row_values in acArray2d:
                tm_file.write("".join(str(v) + "," for v in row_values))
                tm_file.write("\n")
            tm_file.write("\n")

    # --- feature vectors, one row per (action, state) pair ----------------
    with open(out_dir + "features_matrix.txt", "w") as phis_file:
        for act_id in range(1, n_actions + 1):
            a = dict_actEnum[act_id]
            for state_id in range(1, n_states + 1):
                s = dict_stateEnum[state_id]
                arraysPhis = sortingReward.features(s, a)
                phis_file.write("".join(str(phi) + "," for phi in arraysPhis))
                phis_file.write("\n")
            phis_file.write("\n")

    # --- true weights: one column per distinct value in true_assignments --
    expert_ids = np.unique(true_assignments)
    wts_experts_array = np.empty((sortingReward._dim, len(expert_ids)))
    for j, wt_ind in enumerate(expert_ids):
        for i in range(wts_experts_array.shape[0]):
            wts_experts_array[i][j] = List_TrueWeights[wt_ind][i]

    with open(out_dir + "weights_experts.log", "w") as wts_file:
        for row_values in wts_experts_array:
            wts_file.write("".join(str(v) + "," for v in row_values))
            wts_file.write("\n")
def parsePolicies(stdout, lineFoundWeights, lineFeatureExpec,
                  learned_weights, num_Trajsofar, BatchIRLflag):
    """Parse solver output into a policy agent plus session results.

    ``stdout`` is split on newlines. Lines of the form
    ``<state> = <action>`` are read until a line equal to ``ENDPOLICY`` (or
    the end of the text); lines without the `` = `` separator are skipped.
    For each policy line the first and last characters of ``<state>`` are
    dropped, the rest is split on commas and its first four fields are
    converted with ``int`` for ``sortingState``.

    If lines remain after the policy section and ``BatchIRLflag == False``,
    the next three lines are read as: the learned weights (outer characters
    stripped, split on ``", "``, converted to floats), the feature
    expectation line (kept as text), and the trajectory count (an int).

    If nothing remains after the policy section, the values passed in for
    ``lineFoundWeights``, ``lineFeatureExpec``, ``learned_weights`` and
    ``num_Trajsofar`` are returned unchanged and the session is reported as
    not finished.

    ``stdout`` may be ``None``; it is then handled like empty output instead
    of failing on ``None.split``.

    An unknown action name prints a message and ends the process through
    ``exit(0)``.

    Returns the tuple ``([MapAgent], lineFoundWeights, lineFeatureExpec,
    learned_weights, num_Trajsofar, sessionFinish)``.
    """
    if stdout is None:
        print("no stdout in parse policies")
        # Fall through with empty text: this produces an empty policy and the
        # "no results" branch below.
        stdout = ""

    # Each entry is wrapped in a lambda so that an action class name is only
    # looked up when that action really occurs in the input.
    action_builders = {
        "InspectAfterPicking": lambda: InspectAfterPicking(),
        "InspectWithoutPicking": lambda: InspectWithoutPicking(),
        "Pick": lambda: Pick(),
        "PlaceOnConveyor": lambda: PlaceOnConveyor(),
        "PlaceInBin": lambda: PlaceInBin(),
        "ClaimNewOnion": lambda: ClaimNewOnion(),
        "ClaimNextInList": lambda: ClaimNextInList(),
        "PlaceInBinClaimNextInList": lambda: PlaceInBinClaimNextInList(),
    }

    stateactions = stdout.split("\n")
    # After the loop, counter is the index of the first line that follows the
    # policy section (just past ENDPOLICY, or past the end of the text).
    counter = 0
    p = {}
    for stateaction in stateactions:
        counter += 1
        if stateaction == "ENDPOLICY":
            break
        temp = stateaction.split(" = ")
        if len(temp) < 2:
            continue

        pieces = temp[0][1:-1].split(",")
        ss = sortingState(int(pieces[0]), int(pieces[1]), int(pieces[2]),
                          int(pieces[3]))

        builder = action_builders.get(temp[1])
        if builder is None:
            print("Invalid input policy to parsePolicies")
            exit(0)

        p[ss] = builder()

    returnval = [mdp.agent.MapAgent(p)]

    remaining = stateactions[counter:]
    sessionFinish = True
    # The "== False" comparison is kept on purpose: a None flag must not
    # enter this branch.
    if len(remaining) > 0 and BatchIRLflag == False:
        # this change is not reflected in updatewithalg
        lineFoundWeights = stateactions[counter]
        stripped_weights = lineFoundWeights[1:-1].split(", ")
        print(stripped_weights)
        learned_weights = [float(x) for x in stripped_weights]

        lineFeatureExpec = stateactions[counter + 1]
        num_Trajsofar = int(stateactions[counter + 2])

    elif len(remaining) == 0:
        # Nothing after the policy: keep the caller's previous values.
        sessionFinish = False
        print("\n no results from i2rl session")

    return (returnval, lineFoundWeights, lineFeatureExpec,
            learned_weights, num_Trajsofar, sessionFinish)
#!/usr/bin/env python 
import rospy 
import sys 

sys.path.append('/home/psuresh/catkin_ws/src/sorting_patrol_MDP_irl')
from sortingMDP.model import sortingModelbyPSuresh4multipleInit_onlyPIP, \
    sortingState, ClaimNewOnion

if __name__ == '__main__':
    # Smoke test: start a ROS node, build the sorting MDP and print the first
    # successor state that T reports for ClaimNewOnion from state (2, 2, 2, 2).
    rospy.init_node('test_node')
    s_mdp = sortingModelbyPSuresh4multipleInit_onlyPIP()
    s = sortingState(2, 2, 2, 2)
    a = ClaimNewOnion()
    # keys() is a list on Python 2 but a non-subscriptable view on Python 3;
    # wrapping it in list() makes the indexing work on both.
    ns = list(s_mdp.T(s, a).keys())[0]
    print(ns)
    
def main():
    """Run a sorting policy on the robot, one smach state machine per action.

    Steps, in order:

    1. Create the ROS rate object (10 Hz) and the sorting MDP model.
    2. Choose the policy array: the hard-coded array below, or, when the
       local ``call_service`` switch is set, the policy returned by the
       ``/runRobustIrlGetPolicy`` service.
    3. Build one state machine for each of ClaimNewOnion, Pick,
       InspectAfterPicking, PlaceInBin and PlaceOnConveyor.
    4. Starting from ``sortingState(2, 2, 2, 2)``, loop until ROS shuts down:
       look up the action for the current state, run its state machine and,
       on outcome 'SUCCEEDED', move to the first state in ``s_mdp.T(s, a)``
       that differs from the current one. After InspectAfterPicking the
       state's prediction is overwritten from the state machine's colour
       data. Outcome 'SORTCOMPLETE' ends the process through ``exit(0)``.
    """
    # rospy.init_node('execute_I2RL_policy',anonymous=True, disable_signals=True)
    global pnp
    rate = rospy.Rate(10)

    s_mdp = sortingModelbyPSuresh4multipleInit_onlyPIP()
    print("rospy.init_node('execute_I2RL_policy'), s_mdp ")

    call_service = False
    if not call_service:
        policy_pip = np.array([
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 4,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 4, 3, 4, 4, 4, 4, 3,
            4, 4, 4, 4, 3, 4, 4, 4, 4, 3, 4, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
            4, 4, 6, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 4, 4, 4, 4,
            4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4, 4, 4, 4, 4, 4,
            4, 4, 4, 4, 4, 4, 3, 4, 4, 4, 4, 3, 4, 4, 4, 4, 3, 4, 4, 4, 4, 4,
            4, 4, 0, 4
        ])
    else:
        nOnionLoc = 5
        nEEFLoc = 4
        nPredict = 3
        nlistIDStatus = 3
        # Baseline "all ClaimNewOnion" policy used to detect whether a session
        # changed anything.
        prev_array_pol = [4] * (nOnionLoc * nEEFLoc * nPredict * nlistIDStatus)

        rospy.wait_for_service('/runRobustIrlGetPolicy')
        try:
            irl_service = rospy.ServiceProxy("/runRobustIrlGetPolicy",
                                             requestPolicy)
            session_index = 0
            num_sessions = 2
            for _ in range(num_sessions):
                response = irl_service()

                if len(response.policy) == 0:
                    print('\nNo policy response from irl service')
                    continue

                policy_changed = response.policy != prev_array_pol
                if not policy_changed:
                    print(
                        "current session did not change policy: either solver broke half-way (a non-reproducible problem) or learning can't be improved further "
                    )
                    continue

                print("I2RL session number ", session_index)
                # print ('\n policy learned by IRL \n',response.policy)
                policy_pip = response.policy
                prev_array_pol = response.policy
                session_index += 1

        except rospy.ServiceException as e:
            print("Service call failed: %s" % e)
        # NOTE(review): if no session yields a new policy, policy_pip is never
        # bound on this path and the loop below fails on its first lookup.

    # Design notes carried over from the original author:
    #   - Use a fixed initial state instead of querying the current state.
    #   - Use the deterministic transition model from sortingMDP to identify
    #     the next state, then pick the next action.
    #   - PlaceInBin leads to [2, 2, 2, 2], from where a new onion can be
    #     claimed.
    #   - PlaceOnConveyor leads to [0, 2, 0, 2], where a new onion is already
    #     claimed, so its state machine does two things: place the object on
    #     the conveyor and move to the next index of the bounding box.

    def _reset_userdata(machine):
        # Every state machine starts with the same empty userdata fields.
        machine.userdata.sm_x = []
        machine.userdata.sm_y = []
        machine.userdata.sm_z = []
        machine.userdata.sm_color = []
        machine.userdata.sm_counter = 0

    # The remapping helpers return a new dict on every call so that no two
    # states share one mapping object.
    def _remap_pose_color_counter():
        return {
            'x': 'sm_x',
            'y': 'sm_y',
            'z': 'sm_z',
            'color': 'sm_color',
            'counter': 'sm_counter'
        }

    def _remap_color_counter():
        return {'color': 'sm_color', 'counter': 'sm_counter'}

    def _remap_counter():
        return {'counter': 'sm_counter'}

    # state machines for actions claimnewonion, pick, inspectafterpicking, placeinbin, placeonconveyor
    sm_claimnewonion = StateMachine(
        outcomes=['TIMED_OUT', 'SUCCEEDED', 'SORTCOMPLETE'])
    _reset_userdata(sm_claimnewonion)
    with sm_claimnewonion:
        StateMachine.add('GETINFO',
                         Get_info(),
                         transitions={
                             'updated': 'CLAIM',
                             'not_updated': 'GETINFO',
                             'timed_out': 'TIMED_OUT',
                             'completed': 'SORTCOMPLETE'
                         },
                         remapping=_remap_pose_color_counter())
        StateMachine.add('CLAIM',
                         Claim(),
                         transitions={
                             'updated': 'SUCCEEDED',
                             'not_updated': 'CLAIM',
                             'timed_out': 'TIMED_OUT',
                             'not_found': 'GETINFO',
                             'completed': 'SORTCOMPLETE'
                         },
                         remapping=_remap_pose_color_counter())

    print("Hey I got here!")

    sm_pick = StateMachine(outcomes=['TIMED_OUT', 'SUCCEEDED'])
    _reset_userdata(sm_pick)
    with sm_pick:
        StateMachine.add('APPROACH',
                         Approach(),
                         transitions={
                             'success': 'PICK',
                             'failed': 'APPROACH',
                             'timed_out': 'TIMED_OUT',
                             'not_found': 'TIMED_OUT'
                         },
                         remapping=_remap_pose_color_counter())
        StateMachine.add('PICK',
                         PickSM(),
                         transitions={
                             'success': 'GRASP',
                             'failed': 'PICK',
                             'timed_out': 'TIMED_OUT',
                             'not_found': 'TIMED_OUT'
                         },
                         remapping=_remap_pose_color_counter())
        StateMachine.add('GRASP',
                         Grasp_object(),
                         transitions={
                             'success': 'LIFTUP',
                             'failed': 'GRASP',
                             'timed_out': 'TIMED_OUT',
                             'not_found': 'TIMED_OUT'
                         },
                         remapping=_remap_pose_color_counter())
        StateMachine.add('LIFTUP',
                         Liftup(),
                         transitions={
                             'success': 'SUCCEEDED',
                             'failed': 'LIFTUP',
                             'timed_out': 'TIMED_OUT'
                         },
                         remapping=_remap_counter())

    sm_inspectafterpicking = StateMachine(outcomes=['TIMED_OUT', 'SUCCEEDED'])
    _reset_userdata(sm_inspectafterpicking)
    with sm_inspectafterpicking:
        StateMachine.add('VIEW',
                         View(),
                         transitions={
                             'success': 'SUCCEEDED',
                             'failed': 'VIEW',
                             'timed_out': 'TIMED_OUT'
                         },
                         remapping=_remap_counter())

    sm_placeinbin = StateMachine(outcomes=['TIMED_OUT', 'SUCCEEDED'])
    _reset_userdata(sm_placeinbin)
    with sm_placeinbin:
        StateMachine.add('PLACEINBIN',
                         PlaceInBinSM(),
                         transitions={
                             'success': 'DETACH',
                             'failed': 'PLACEINBIN',
                             'timed_out': 'TIMED_OUT'
                         },
                         remapping=_remap_color_counter())
        StateMachine.add('DETACH',
                         Detach_object_wo_ClaimNew(),
                         transitions={
                             'success': 'SUCCEEDED',
                             'failed': 'DETACH',
                             'timed_out': 'TIMED_OUT'
                         },
                         remapping=_remap_pose_color_counter())

    sm_placeonconveyor = StateMachine(
        outcomes=['TIMED_OUT', 'SUCCEEDED', 'SORTCOMPLETE'])
    _reset_userdata(sm_placeonconveyor)
    with sm_placeonconveyor:
        StateMachine.add('PLACEONCONVEYOR',
                         PlaceOnConveyorSM(),
                         transitions={
                             'success': 'DETACH',
                             'failed': 'PLACEONCONVEYOR',
                             'timed_out': 'TIMED_OUT'
                         },
                         remapping=_remap_color_counter())
        StateMachine.add('DETACH',
                         Detach_object(),
                         transitions={
                             'success': 'LIFTUP',
                             'failed': 'DETACH',
                             'timed_out': 'TIMED_OUT',
                             'completed': 'SORTCOMPLETE'
                         },
                         remapping=_remap_pose_color_counter())
        StateMachine.add('LIFTUP',
                         Liftup(),
                         transitions={
                             'success': 'SUCCEEDED',
                             'failed': 'LIFTUP',
                             'timed_out': 'TIMED_OUT'
                         },
                         remapping=_remap_counter())

    # Policy entries are action ids; the list position is the id.
    aid2act = dict(enumerate([
        'InspectAfterPicking',
        'PlaceOnConveyor',
        'PlaceInBin',
        'Pick',
        'ClaimNewOnion',
        'InspectWithoutPicking',
        'ClaimNextInList',
    ]))

    # NOTE(review): ids 5 and 6 have no state machine here, so a policy that
    # selects them fails with KeyError at the act2sm lookup in the loop.
    act2sm = {
        'InspectAfterPicking': sm_inspectafterpicking,
        'PlaceOnConveyor': sm_placeonconveyor,
        'PlaceInBin': sm_placeinbin,
        'Pick': sm_pick,
        'ClaimNewOnion': sm_claimnewonion
    }

    # Action name -> constructor of the matching MDP action. Wrapped in
    # lambdas so a class name is only looked up when that action is executed.
    act2mdp_action = {
        'ClaimNewOnion': lambda: ClaimNewOnion(),
        'PlaceOnConveyor': lambda: PlaceOnConveyor(),
        'PlaceInBin': lambda: PlaceInBin(),
        'Pick': lambda: Pick(),
        'InspectAfterPicking': lambda: InspectAfterPicking(),
    }

    prev_sm = sm_claimnewonion
    s = sortingState(2, 2, 2, 2)

    while not rospy.is_shutdown():
        print("current state ", s)

        # use policy to pick current action and SM
        sid = vals2sid(s._onion_location, s._EE_location, s._prediction,
                       s._listIDs_status)
        aid = policy_pip[sid]
        current_action = aid2act[aid]
        print("current action ", current_action)

        # The machine about to run takes over the userdata object of the one
        # that ran before it, with the counter reset.
        current_sm = act2sm[current_action]
        current_sm.userdata = prev_sm.userdata
        current_sm.userdata.sm_counter = 0

        # execute action
        outcome = current_sm.execute()
        print("outcome ", outcome)
        prev_sm = current_sm

        if outcome == 'SORTCOMPLETE':
            exit(0)
        elif outcome == 'SUCCEEDED':
            make_action = act2mdp_action.get(current_action)
            if make_action is None:
                # In this case `a` keeps whatever value it had before.
                print("Unepxected action chosen action.")
            else:
                a = make_action()

            # Advance to the first successor that is not the current state.
            for candidate in s_mdp.T(s, a).keys():
                if candidate != s:
                    s = candidate
                    print("ns!=s next state is ", s)
                    break

            if current_action == 'InspectAfterPicking':
                # Overwrite the model's prediction with the stored colour
                # value for the onion index held by pnp.
                print("prediction ",
                      int(prev_sm.userdata.sm_color[pnp.onion_index]))
                s._prediction = int(prev_sm.userdata.sm_color[pnp.onion_index])

        rate.sleep()