Example #1
    def __init__(self, s0, a, s1=None, amdp_id=2):
        self.action_type = a.action_type
        self.action_object = a.object
        self.amdp_id = amdp_id

        # convert action into something that fits into the new action list
        if a.action_type == Action.PLACE:
            self.action_object = DataUtils.get_task_frame(s0, a.position)
        elif a.action_type == Action.MOVE_ARM:
            self.action_object = DataUtils.get_task_frame(s0, a.position)
            if (amdp_id <= 2 and self.action_object != 'stack' and self.action_object != 'drawer') or \
                                (amdp_id >= 6 and self.action_object != 'box' and self.action_object != 'lid'):
                for o in s0.objects:
                    if o.name != 'apple':
                        continue
                    if a.position == o.position:
                        self.action_object = 'apple'
                        break
                if self.action_object != 'apple':
                    x = s0.gripper_position.x
                    y = s0.gripper_position.y
                    px = a.position.x
                    py = a.position.y
                    if px == x and py > y:
                        self.action_object = 'b'
                    elif px < x and py > y:
                        self.action_object = 'bl'
                    elif px < x and py == y:
                        self.action_object = 'l'
                    elif px < x and py < y:
                        self.action_object = 'fl'
                    elif px == x and py < y:
                        self.action_object = 'f'
                    elif px > x and py < y:
                        self.action_object = 'fr'
                    elif px > x and py == y:
                        self.action_object = 'r'
                    else:
                        self.action_object = 'br'
        elif a.action_type == Action.GRASP:
            pass
        else:
            self.action_object = ''

        # compute amdp state representation
        s0_prime = AMDPState(amdp_id=self.amdp_id, state=OOState(state=s0))
        if s1 is not None:
            s1_prime = AMDPState(amdp_id=self.amdp_id, state=OOState(state=s1))

        # ********************************  Preconditions  ********************************
        self.preconditions = self.state_to_preconditions(s0_prime)

        # ********************************  Effects  ********************************
        self.effects = {}
        if s1 is not None:
            result = self.state_to_preconditions(s1_prime)
            for key, value in result.iteritems():
                if key in self.preconditions and self.preconditions[key] != value:
                    self.effects[key] = value
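
The eight labels assigned above ('b', 'bl', 'l', 'fl', 'f', 'fr', 'r', 'br') encode the direction of the target position relative to the gripper, composed from a front/back component (sign of py - y) and a left/right component (sign of px - x). The same mapping as a standalone sketch (a hypothetical helper, not part of the original class; reading the letters as front/back/left/right is inferred from the offset signs):

    def direction_label(px, py, x, y):
        # front/back component: +y offset is 'b', -y offset is 'f'
        fb = 'b' if py > y else ('f' if py < y else '')
        # left/right component: -x offset is 'l', +x offset is 'r'
        lr = 'l' if px < x else ('r' if px > x else '')
        # the original chain defaults to 'br' when both offsets are zero
        return (fb + lr) or 'br'
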
Example #2
 def check_preconditions(self, state, ground_items=None):
     s = AMDPState(amdp_id=self.amdp_id,
                   state=OOState(state=state),
                   ground_items=ground_items)
     ps = self.state_to_preconditions(s)
     for key, value in self.preconditions.iteritems():
         if not (key in ps and ps[key] == value):
             return False
     return True
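
Examples #1 and #2 pair up: the constructor records the preconditions and effects of a demonstrated transition, and check_preconditions tests whether a new state satisfies those preconditions. A minimal usage sketch, assuming the class is AMDPPlanAction (it is instantiated under that name in Example #7) and that s0, demo_action, s1, and current_state are hypothetical state/action messages of the kinds used above:

    edge = AMDPPlanAction(s0, demo_action, s1=s1, amdp_id=2)
    if edge.check_preconditions(current_state):
        # the demonstrated action is applicable in current_state
        pass
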
Example #3
    def query_status(self, req):
        # Check termination criteria
        failed = False
        status = Status()
        status.status_code = Status.IN_PROGRESS
        # the same termination test applies to apples, bananas, and carrots:
        # fail if the item is lost, within 3 units of (20, 1), or 20+ units away
        for obj in req.state.objects:
            if obj.name.lower() in ('apple', 'banana', 'carrot'):
                dst = sqrt(
                    pow(20 - obj.position.x, 2) +
                    pow(1 - obj.position.y, 2))
                if obj.lost or dst <= 3 or dst >= 20:
                    failed = True
                    break

        # completion requires the drawer to be closed and the lid to be on the box
        completed = True
        if req.state.drawer_opening > 1:
            completed = False
        if req.state.lid_position.x != req.state.box_position.x or req.state.lid_position.y != req.state.box_position.y:
            completed = False

        oo_state = OOState(state=req.state, continuous=self.continuous)
        amdp_id = 12
        s = AMDPState(amdp_id=amdp_id, state=oo_state)
        if completed and is_terminal(s, amdp_id=amdp_id):
            status.status_code = Status.COMPLETED

        if failed:
            status.status_code = Status.FAILED
            return status

        return status
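
For a concrete reading of the failure test: with the reference point at (20, 1), an object at (18, 2) gives dst = sqrt((20 - 18)**2 + (1 - 2)**2) = sqrt(5) ≈ 2.24, which trips dst <= 3, while an object at (45, 1) gives dst = 25, which trips dst >= 20. An item therefore fails the episode when it is lost, within 3 units of (20, 1), or 20 or more units away from it.
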
Example #4
    def select_action(self, req, debug=1):
        action = Action()

        action_list = []

        oo_state = OOState(state=req.state, continuous=self.continuous)

        if self.complexity > 0:
            # TODO: this is commented out for drawer-only testing!
            # start at the top level
            s = AMDPState(amdp_id=12, state=oo_state)
            utilities = {}
            for a in self.A[12]:
                successors = self.T[t_id_map[12]].transition_function(s, a)
                u = 0
                for i in range(len(successors)):
                    p = successors[i][0]
                    s_prime = successors[i][1]
                    if s_prime in self.U[12]:
                        u += p * self.U[12][s_prime]
                    elif is_terminal(s_prime, amdp_id=12):
                        u += p * reward(s_prime, amdp_id=12)
                utilities[a] = u

            # print '\n---'
            # for key in utilities:
            #     print str(key)
            #     print 'utility: ' + str(utilities[key])

            # pick top action deterministically
            max_utility = -999999
            for a in utilities.keys():
                if utilities[a] > max_utility:
                    max_utility = utilities[a]
                    action_list = []
                    action_list.append(deepcopy(a))
                elif utilities[a] == max_utility:
                    action_list.append(deepcopy(a))

            # select action
            # i = randint(0, len(action_list) - 1)
            i = 0
            id = action_list[i].action_type
            #obj = action_list[i].object

            if debug > 0:
                print 'Top level action selection: ' + str(id)

            s = AMDPState(amdp_id=id, state=oo_state)
            # s = AMDPState(amdp_id=4, state=oo_state)  # TODO: temporary, for drawer-only testing

        else:
            if self.env_type % 2 == 0:
                id = 4
            else:
                id = 11

            s = AMDPState(amdp_id=id,
                          state=oo_state,
                          ground_items=['apple', 'apple', 'apple', 'apple'])

        # TODO: debugging state
        print '\n\n-------------------------------------------------------------'
        print 'Mid-level AMDP state:'
        print str(s)
        print '-------------------------------------------------------------\n\n'

        utilities = {}
        for a in self.A[id]:
            successors = self.T[t_id_map[id]].transition_function(s, a)
            u = 0
            for i in range(len(successors)):
                p = successors[i][0]
                s_prime = successors[i][1]
                if s_prime in self.U[id]:
                    u += p * self.U[id][s_prime]
                elif is_terminal(s_prime, amdp_id=id):
                    u += p * reward(s_prime, amdp_id=id)
            utilities[a] = u

        # print '\n---'
        # for key in utilities:
        #     print str(key)
        #     print 'utility: ' + str(utilities[key])

        # pick top action deterministically
        max_utility = -999999
        for a in utilities.keys():
            if utilities[a] > max_utility:
                max_utility = utilities[a]
                action_list = []
                action_list.append(deepcopy(a))
            elif utilities[a] == max_utility:
                action_list.append(deepcopy(a))

        # select action
        # i = randint(0, len(action_list) - 1)
        i = 0
        id = action_list[i].action_type
        if self.complexity > 0:
            obj = action_list[i].object
        else:
            if action_list[i].object in [
                    'apple', 'banana', 'carrot', 'daikon'
            ]:
                obj = 'apple'
            else:
                obj = action_list[i].object

        if debug > 0:
            print '\tMid level action selection: ' + str(id) + ', ' + str(obj)

        # solve lower level mdp for executable action
        action_list = []
        s = AMDPState(amdp_id=id, state=oo_state, ground_items=[obj])

        # TODO: debugging state
        print '\n\n-------------------------------------------------------------'
        print 'Low-level AMDP state:'
        print str(s)
        print '-------------------------------------------------------------\n\n'

        selected_from_utility = 1

        if self.q_learning_mode:
            action = self.Q[id].select_action(s, action_list=self.A[id])
            if action is None:
                selected_from_utility = 0
                if self.demo_mode.classifier:
                    action = Action()
                    features = s.to_vector()
                    probs = self.classifiers[t_id_map[id]].predict_proba(
                        np.asarray(features).reshape(1, -1)).flatten().tolist()
                    selection = random()
                    cprob = 0
                    action_label = '0:apple'
                    for i in range(0, len(probs)):
                        cprob += probs[i]
                        if cprob >= selection:
                            action_label = self.classifiers[
                                t_id_map[id]].classes_[i]
                            break
                    # Convert back to action
                    result = action_label.split(':')
                    action.action_type = int(result[0])
                    if len(result) > 1:
                        action.object = result[1]
                else:
                    action = self.A[id][randint(0, len(self.A[id]) - 1)]
            if action.object == 'apple':
                if obj not in items:
                    action.object = items[randint(0, len(items) - 1)]
                else:
                    action.object = obj
        elif self.baseline_mode:
            selected_from_utility = 0
            if self.demo_mode.classifier:
                features = s.to_vector()
                probs = self.classifiers[t_id_map[id]].predict_proba(
                    np.asarray(features).reshape(1, -1)).flatten().tolist()
                selection = random()
                cprob = 0
                action_label = '0:apple'
                for i in range(0, len(probs)):
                    cprob += probs[i]
                    if cprob >= selection:
                        action_label = self.classifiers[
                            t_id_map[id]].classes_[i]
                        break
                # Convert back to action
                result = action_label.split(':')
                action.action_type = int(result[0])
                if len(result) > 1:
                    action.object = result[1]
                    if action.object == 'apple':
                        if obj not in items:
                            action.object = items[randint(0, len(items) - 1)]
                        else:
                            action.object = obj
            elif self.demo_mode.plan_network:
                current_node = self.action_sequences[
                    t_id_map[id]].find_suitable_node(req.state,
                                                     ground_items=[obj])
                if current_node is None:
                    current_node = 'start'
                action_list = self.action_sequences[
                    t_id_map[id]].get_successor_actions(current_node,
                                                        req.state,
                                                        ground_items=[obj])
                # select action stochastically if we're in the network, select randomly otherwise
                if len(action_list) == 0:
                    # random
                    action = self.A[id][randint(0, len(self.A[id]) - 1)]
                    if action.object == 'apple':
                        if obj not in items:
                            action.object = items[randint(0, len(items) - 1)]
                        else:
                            action.object = obj
                else:
                    selection = random()
                    count = 0
                    selected_action = action_list[0]
                    for i in range(len(action_list)):
                        count += action_list[i][1]
                        if count >= selection:
                            selected_action = action_list[i]
                            break
                    action.action_type = selected_action[0].action_type
                    action.object = selected_action[0].action_object
                    if action.object == 'apple':
                        if obj not in items:
                            action.object = items[randint(0, len(items) - 1)]
                        else:
                            action.object = obj
            else:
                action = self.A[id][randint(0, len(self.A[id]) - 1)]
                if action.object == 'apple':
                    if obj not in items:
                        action.object = items[randint(0, len(items) - 1)]
                    else:
                        action.object = obj

        else:
            utilities = {}
            for a in self.A[id]:
                successors = self.T[t_id_map[id]].transition_function(s, a)
                u = 0
                for i in range(len(successors)):
                    p = successors[i][0]
                    s_prime = successors[i][1]
                    if s_prime in self.U[id]:
                        u += p * self.U[id][s_prime]
                    elif is_terminal(s_prime, amdp_id=id):
                        u += p * reward(s_prime, amdp_id=id)
                utilities[a] = u

            # print '\n---'
            # for key in utilities:
            #     print str(key)
            #     print 'utility: ' + str(utilities[key])

            # pick top action deterministically
            max_utility = -999999
            for a in utilities.keys():
                if utilities[a] > max_utility:
                    max_utility = utilities[a]
                    action_list = []
                    action = deepcopy(a)
                    if action.object == 'apple':
                        if obj not in items:
                            action.object = items[randint(0, len(items) - 1)]
                        else:
                            action.object = obj
                    action_list.append(deepcopy(action))
                elif utilities[a] == max_utility:
                    action = deepcopy(a)
                    if action.object == 'apple':
                        if obj not in items:
                            action.object = items[randint(0, len(items) - 1)]
                        else:
                            action.object = obj
                    action_list.append(deepcopy(action))
                if debug > 1:
                    print 'Action: ', a.action_type, ':', a.object, ', Utility: ', utilities[a]

            if max_utility > 0:  # at least one successor state is in the utility table
                i = randint(0, len(action_list) - 1)
                # i = 0
                action = action_list[i]
                if debug > 0:
                    print('Action selected from utilities')
            else:  # we need to select an action a different way
                selected_from_utility = 0
                if self.demo_mode.plan_network and not self.demo_mode.classifier:
                    current_node = self.action_sequences[
                        t_id_map[id]].find_suitable_node(req.state,
                                                         ground_items=[obj])
                    if current_node is None:
                        current_node = 'start'
                    action_list = self.action_sequences[
                        t_id_map[id]].get_successor_actions(current_node,
                                                            req.state,
                                                            ground_items=[obj])

                    # select action stochastically if we're in the network, select randomly otherwise
                    if len(action_list) == 0:
                        # random
                        action = self.A[id][randint(0, len(self.A[id]) - 1)]
                        if action.object == 'apple':
                            if obj not in items:
                                action.object = items[randint(0, len(items) - 1)]
                            else:
                                action.object = obj
                    else:
                        selection = random()
                        count = 0
                        selected_action = action_list[0]
                        for i in range(len(action_list)):
                            count += action_list[i][1]
                            if count >= selection:
                                selected_action = action_list[i]
                                break
                        action.action_type = selected_action[0].action_type
                        action.object = selected_action[0].action_object
                        if action.object == 'apple':
                            if obj not in items:
                                action.object = items[randint(0, len(items) - 1)]
                            else:
                                action.object = obj
                elif self.demo_mode.plan_network and self.demo_mode.classifier:
                    # 50/50 tradeoff between plan network and classifier
                    use_plan_network = random() < 0.5
                    use_classifier = not use_plan_network

                    if use_plan_network:
                        current_node = self.action_sequences[
                            t_id_map[id]].find_suitable_node(
                                req.state, ground_items=[obj])
                        if current_node is None:
                            current_node = 'start'
                        action_list = self.action_sequences[
                            t_id_map[id]].get_successor_actions(
                                current_node, req.state, ground_items=[obj])

                        # select action stochastically if we're in the network, select with classifier otherwise
                        if len(action_list) == 0:
                            use_classifier = True
                        else:
                            selection = random()
                            count = 0
                            selected_action = action_list[0]
                            for i in range(len(action_list)):
                                count += action_list[i][1]
                                if count >= selection:
                                    selected_action = action_list[i]
                                    break
                            action.action_type = selected_action[0].action_type
                            action.object = selected_action[0].action_object
                            if action.object == 'apple':
                                if obj not in items:
                                    action.object = items[randint(0, len(items) - 1)]
                                else:
                                    action.object = obj

                            if debug > 0:
                                print('Action selected from plan network')

                    if use_classifier:
                        features = s.to_vector()
                        probs = self.classifiers[t_id_map[id]].predict_proba(
                            np.asarray(features).reshape(
                                1, -1)).flatten().tolist()
                        selection = random()
                        cprob = 0
                        action_label = '0:apple'
                        for i in range(0, len(probs)):
                            cprob += probs[i]
                            if cprob >= selection:
                                action_label = self.classifiers[
                                    t_id_map[id]].classes_[i]
                                break
                        # Convert back to action
                        result = action_label.split(':')
                        action.action_type = int(result[0])
                        if len(result) > 1:
                            action.object = result[1]
                            if action.object == 'apple':
                                if obj not in items:
                                    action.object = items[randint(0, len(items) - 1)]
                                else:
                                    action.object = obj
                        if debug > 0:
                            print('Action selected from classifier')

                elif self.demo_mode.classifier:
                    features = s.to_vector()

                    probs = self.classifiers[t_id_map[id]].predict_proba(
                        np.asarray(features).reshape(1, -1)).flatten().tolist()
                    selection = random()
                    cprob = 0
                    action_label = '0:apple'
                    for i in range(0, len(probs)):
                        cprob += probs[i]
                        if cprob >= selection:
                            action_label = self.classifiers[
                                t_id_map[id]].classes_[i]
                            break

                    # Convert back to action
                    result = action_label.split(':')
                    action.action_type = int(result[0])
                    if len(result) > 1:
                        action.object = result[1]
                        if action.object == 'apple':
                            if obj not in items:
                                action.object = items[randint(0, len(items) - 1)]
                            else:
                                action.object = obj
                    if debug > 0:
                        print '***** Action selected from decision tree. *****'

                # random action
                # if self.demo_mode.random:
                else:
                    action = self.A[id][randint(0, len(self.A[id]) - 1)]
                    if action.object == 'apple':
                        if obj not in items:
                            action.object = items[randint(0, len(items) - 1)]
                        else:
                            action.object = obj

        if debug > 0:
            print '\t\tLow level action selection: ' + str(action.action_type) + ', ' + str(action.object)
        if action.action_type == Action.PLACE:
            if not self.continuous:
                action.position = DataUtils.semantic_action_to_position(
                    req.state, action.object)
                action.object = ''
        elif action.action_type == Action.MOVE_ARM:
            if not self.continuous:
                if action.object == 'l':
                    action.position.x = req.state.gripper_position.x - 10
                    action.position.y = req.state.gripper_position.y
                elif action.object == 'fl':
                    action.position.x = req.state.gripper_position.x - 10
                    action.position.y = req.state.gripper_position.y - 5
                elif action.object == 'f':
                    action.position.x = req.state.gripper_position.x
                    action.position.y = req.state.gripper_position.y - 5
                elif action.object == 'fr':
                    action.position.x = req.state.gripper_position.x + 10
                    action.position.y = req.state.gripper_position.y - 5
                elif action.object == 'r':
                    action.position.x = req.state.gripper_position.x + 10
                    action.position.y = req.state.gripper_position.y
                elif action.object == 'br':
                    action.position.x = req.state.gripper_position.x + 10
                    action.position.y = req.state.gripper_position.y + 5
                elif action.object == 'b':
                    action.position.x = req.state.gripper_position.x
                    action.position.y = req.state.gripper_position.y + 5
                elif action.object == 'bl':
                    action.position.x = req.state.gripper_position.x - 10
                    action.position.y = req.state.gripper_position.y + 5
                else:
                    action.position = DataUtils.semantic_action_to_position(
                        req.state, action.object)
                action.object = ''
        elif action.action_type != Action.GRASP:
            action.object = ''

        # print '\n\n-------------------'
        # print 'Selected action: '
        # print str(action)

        return action, selected_from_utility
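
The one-step expected-utility computation appears three times in select_action (top, mid, and low level). A hedged refactoring sketch of that shared pattern, assuming transition_function returns (probability, successor_state) pairs as iterated above, and that T and U are the already-indexed transition model and utility table:

    def expected_utility(T, U, s, a, amdp_id):
        # E[U(s') | s, a]: successor utilities weighted by transition
        # probability, substituting the immediate reward for terminal
        # successors that have no utility table entry
        u = 0
        for p, s_prime in T.transition_function(s, a):
            if s_prime in U:
                u += p * U[s_prime]
            elif is_terminal(s_prime, amdp_id=amdp_id):
                u += p * reward(s_prime, amdp_id=amdp_id)
        return u
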
Example #5
    def select_action(self, req):
        action = Action()  # default instance; set to NOOP below if no action is found

        action_list = []

        oo_state = OOState(state=req.state)
        s = RelationState(state=oo_state)
        print str(s.relations)

        utilities = {}
        for a in self.A:
            successors = transition_function(s, a)
            u = 0
            for i in range(len(successors)):
                p = successors[i][0]
                s_prime = successors[i][1]
                if s_prime in self.U:
                    u += p * self.U[s_prime]
                elif is_terminal(s_prime):
                    u += p * reward(s_prime)
            utilities[a] = u

        print '\n---'
        for key in utilities:
            print str(key)
            print 'utility: ' + str(utilities[key])

        # pick top action deterministically
        max_utility = -999999
        for a in utilities.keys():
            if utilities[a] > max_utility:
                max_utility = utilities[a]
                action_list = []
                action_list.append(deepcopy(a))
            elif utilities[a] == max_utility:
                action_list.append(deepcopy(a))

        if len(action_list) > 0:
            i = randint(0, len(action_list) - 1)
            action = action_list[i]
            if action.action_type == Action.PLACE:
                action.position = DataUtils.semantic_action_to_position(
                    req.state, action.object)
                action.object = ''
            elif action.action_type == Action.MOVE_ARM:
                if action.object == 'l':
                    action.position.x = req.state.gripper_position.x - 10
                    action.position.y = req.state.gripper_position.y
                elif action.object == 'fl':
                    action.position.x = req.state.gripper_position.x - 10
                    action.position.y = req.state.gripper_position.y - 5
                elif action.object == 'f':
                    action.position.x = req.state.gripper_position.x
                    action.position.y = req.state.gripper_position.y - 5
                elif action.object == 'fr':
                    action.position.x = req.state.gripper_position.x + 10
                    action.position.y = req.state.gripper_position.y - 5
                elif action.object == 'r':
                    action.position.x = req.state.gripper_position.x + 10
                    action.position.y = req.state.gripper_position.y
                elif action.object == 'br':
                    action.position.x = req.state.gripper_position.x + 10
                    action.position.y = req.state.gripper_position.y + 5
                elif action.object == 'b':
                    action.position.x = req.state.gripper_position.x
                    action.position.y = req.state.gripper_position.y + 5
                elif action.object == 'bl':
                    action.position.x = req.state.gripper_position.x - 10
                    action.position.y = req.state.gripper_position.y + 5
                else:
                    action.position = DataUtils.semantic_action_to_position(
                        req.state, action.object)
                action.object = ''
            elif action.action_type != Action.GRASP:
                action.object = ''
        else:
            action.action_type = Action.NOOP

        print '\n\n-------------------'
        print 'Selected action: '
        print str(action)

        return action
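
Both select_action variants expand the eight MOVE_ARM labels into the same fixed gripper-relative offsets, with unlisted labels falling through to DataUtils.semantic_action_to_position. The same branch chain condensed into a lookup table (a hypothetical refactoring; dx and dy are added to the gripper position):

    MOVE_OFFSETS = {
        'l': (-10, 0), 'fl': (-10, -5), 'f': (0, -5), 'fr': (10, -5),
        'r': (10, 0), 'br': (10, 5), 'b': (0, 5), 'bl': (-10, 5),
    }

    if action.object in MOVE_OFFSETS:
        dx, dy = MOVE_OFFSETS[action.object]
        action.position.x = req.state.gripper_position.x + dx
        action.position.y = req.state.gripper_position.y + dy
    else:
        action.position = DataUtils.semantic_action_to_position(
            req.state, action.object)
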
Example #6
    def read_all(self):
        print 'Parsing demonstrations for amdp_id ' + str(self.amdp_id)

        # Initialize data structures based on parse modes
        state_action_pairs = []
        prev_state = None
        prev_state_msg = None

        for demo_file in self.demo_list:
            print '\nReading ' + demo_file + '...'
            bag = rosbag.Bag(demo_file)
            for topic, msg, t in bag.read_messages(
                    topics=['/table_sim/task_log']):
                # Parse messages based on parse modes
                state = AMDPState(amdp_id=self.amdp_id,
                                  state=OOState(state=msg.state))

                if prev_state_msg is None:
                    prev_state_msg = copy.deepcopy(msg.state)

                if prev_state is None:
                    prev_state = state.to_vector()
                elif msg.action.action_type != Action.NOOP:
                    a = msg.action
                    # convert action into something that fits into the new action list
                    if a.action_type == Action.PLACE:
                        a.object = DataUtils.get_task_frame(
                            prev_state_msg, a.position)
                        a.position = Point()
                    elif a.action_type == Action.MOVE_ARM:
                        a.object = DataUtils.get_task_frame(
                            prev_state_msg, a.position)
                        if (self.amdp_id <= 2 and a.object != 'stack' and a.object != 'drawer') or \
                                (self.amdp_id >= 6 and a.object != 'box' and a.object != 'lid'):
                            for o in prev_state_msg.objects:
                                if o.name != 'apple':
                                    continue
                                if a.position == o.position:
                                    a.object = 'apple'
                                    break
                            if a.object != 'apple':
                                x = prev_state_msg.gripper_position.x
                                y = prev_state_msg.gripper_position.y
                                px = a.position.x
                                py = a.position.y
                                if px == x and py > y:
                                    a.object = 'b'
                                elif px < x and py > y:
                                    a.object = 'bl'
                                elif px < x and py == y:
                                    a.object = 'l'
                                elif px < x and py < y:
                                    a.object = 'fl'
                                elif px == x and py < y:
                                    a.object = 'f'
                                elif px > x and py < y:
                                    a.object = 'fr'
                                elif px > x and py == y:
                                    a.object = 'r'
                                else:
                                    a.object = 'br'
                        a.position = Point()
                    elif a.action_type == Action.GRASP:
                        a.position = Point()
                    else:
                        a.position = Point()
                        a.object = ''

                    pair = {
                        'state': copy.deepcopy(prev_state),
                        'action': str(a.action_type) + ':' + a.object
                    }
                    state_action_pairs.append(pair)

                    # update stored data for next iteration
                    prev_state_msg = copy.deepcopy(msg.state)
                    prev_state = state.to_vector()

            bag.close()

        # Write out data files
        self.write_yaml(state_action_pairs, 'amdp_sa')
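
The 'action' field written here is the string label later consumed by the classifiers in Example #4: action type and object are joined as 'type:object' and recovered with split(':'). A round-trip sketch with hypothetical values:

    label = str(a.action_type) + ':' + a.object  # e.g. '1:apple' (the integer is hypothetical)
    parts = label.split(':')
    decoded = Action()
    decoded.action_type = int(parts[0])
    if len(parts) > 1:
        decoded.object = parts[1]
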
Example #7
    def run(self):
        state_msg = self.query_state().state
        s = AMDPState(amdp_id=self.amdp_id, state=OOState(state=state_msg))

        self.timeout += 1

        goal_reached = goal_check(state_msg, self.amdp_id)
        if self.timeout > self.max_episode_length or goal_reached:
            self.timeout = 0
            # self.reset_sim()
            self.epoch += 1
            if goal_reached:
                self.successes += 1
            if self.demo_mode.plan_network:
                self.current_node = 'start'
                self.prev_state_msg = None
                self.prev_action = None
            return

        exploit_check = random()
        if self.exploit_policy and exploit_check > self.exploit_epsilon:
            a = self.select_action(state_msg, Action()).action
        else:
            # plan network exploration; each behavior is implemented separately to keep the conditionals manageable
            if self.demo_mode.plan_network:
                # determine the current node in the plan network
                if self.prev_state_msg is None or self.prev_action is None:
                    self.current_node = 'start'
                else:
                    self.current_node = AMDPPlanAction(self.prev_state_msg,
                                                       self.prev_action,
                                                       state_msg, self.amdp_id)

                # select action
                a = Action()
                if self.demo_mode.classifier:
                    if random() < self.alpha:
                        action_list = []
                        if self.action_sequences.has_node(self.current_node):
                            action_list = self.action_sequences.get_successor_actions(
                                self.current_node, state_msg)
                        else:
                            self.current_node = self.action_sequences.find_suitable_node(
                                state_msg)
                            if self.current_node is not None:
                                action_list = self.action_sequences.get_successor_actions(
                                    self.current_node, state_msg)

                        # select action stochastically if we're in the network, select randomly otherwise
                        if len(action_list) == 0:
                            a = self.A[randint(0, len(self.A) - 1)]
                        else:
                            selection = random()
                            count = 0
                            selected_action = action_list[0]
                            for i in range(len(action_list)):
                                count += action_list[i][1]
                                if count >= selection:
                                    selected_action = action_list[i]
                                    break
                            a.action_type = selected_action[0].action_type
                            a.object = selected_action[0].action_object
                    else:
                        # demo_mode.classifier is guaranteed True on this branch
                        if self.demo_mode.random and random() <= self.epsilon:
                            a = self.A[randint(0, len(self.A) - 1)]
                        else:
                            features = s.to_vector()

                            # Classify action
                            probs = self.action_bias.predict_proba(
                                np.asarray(features).reshape(
                                    1, -1)).flatten().tolist()
                            selection = random()
                            cprob = 0
                            action_label = '0:apple'
                            for i in range(0, len(probs)):
                                cprob += probs[i]
                                if cprob >= selection:
                                    action_label = self.action_bias.classes_[i]
                                    break
                            # Convert back to action
                            a = Action()
                            result = action_label.split(':')
                            a.action_type = int(result[0])
                            if len(result) > 1:
                                a.object = result[1]
                else:
                    # select from the plan network, with a chance of random exploration, and use random exploration when
                    # off of the network
                    if random() < self.alpha:
                        action_list = []
                        if self.action_sequences.has_node(self.current_node):
                            action_list = self.action_sequences.get_successor_actions(
                                self.current_node, state_msg)
                        else:
                            self.current_node = self.action_sequences.find_suitable_node(
                                state_msg)
                            if self.current_node is not None:
                                action_list = self.action_sequences.get_successor_actions(
                                    self.current_node, state_msg)

                        # select action stochastically if we're in the network, select randomly otherwise
                        if len(action_list) == 0:
                            a = self.A[randint(0, len(self.A) - 1)]
                        else:
                            selection = random()
                            count = 0
                            selected_action = action_list[0]
                            for i in range(len(action_list)):
                                count += action_list[i][1]
                                if count >= selection:
                                    selected_action = action_list[i]
                                    break
                            a.action_type = selected_action[0].action_type
                            a.object = selected_action[0].action_object
                    else:
                        a = self.A[randint(0, len(self.A) - 1)]

                self.prev_state_msg = state_msg  # store state for the next iteration
                self.prev_action = action_to_sim(deepcopy(a), state_msg)

            else:
                if self.demo_mode.shadow and s in self.pi:
                    if random() < self.alpha:
                        a = self.pi[s].select_action()
                    else:
                        a = self.A[randint(0, len(self.A) - 1)]
                else:
                    if self.demo_mode.classifier:
                        if self.demo_mode.random and random() <= self.epsilon:
                            a = self.A[randint(0, len(self.A) - 1)]
                        else:
                            features = s.to_vector()

                            # Classify action
                            probs = self.action_bias.predict_proba(
                                np.asarray(features).reshape(
                                    1, -1)).flatten().tolist()
                            selection = random()
                            cprob = 0
                            action_label = '0:apple'
                            for i in range(0, len(probs)):
                                cprob += probs[i]
                                if cprob >= selection:
                                    action_label = self.action_bias.classes_[i]
                                    break
                            # Convert back to action
                            a = Action()
                            result = action_label.split(':')
                            a.action_type = int(result[0])
                            if len(result) > 1:
                                a.object = result[1]
                    else:
                        a = self.A[randint(0, len(self.A) - 1)]

        self.execute_action(action_to_sim(deepcopy(a), state_msg))
        s_prime = AMDPState(amdp_id=self.amdp_id,
                            state=OOState(state=self.query_state().state))
        self.action_executions += 1

        self.transition_function.update_transition(s, a, s_prime)
        self.n += 1
        self.prev_state = deepcopy(s)
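
One pattern recurs throughout run(): with some probability take a uniformly random action from self.A, otherwise defer to a biased source (plan network, classifier, or shadowed policy). The same logic in isolation (a minimal sketch with hypothetical names; assumes from random import random, randint, as in the surrounding code):

    def epsilon_greedy(A, epsilon, biased_choice):
        # explore uniformly with probability epsilon, otherwise exploit
        # the biased action source
        if random() <= epsilon:
            return A[randint(0, len(A) - 1)]
        return biased_choice()
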