class SolverAdjust:
    def __init__(self, scenario_file, perturb=False, gravity_dir=(-1,-1)):
        self.perturb = perturb
        self.gravity_dir = gravity_dir
        # create scenario
        self.width, self.height, self.immobile_objs, self.mobile_objs, self.manipulatable_obj, self.target_obj = loadScenario(
            scenario_file)
        self.scenario_file = scenario_file
        print 'solving scenario: ', scenario_file
        print 'gravity internal sim: ', gravity_dir

        self.evalReal = EvaluateQualiSol(scenario_file, gravity_dir)
        #self.evalReal = EvaluateQualiSol('./scenarios/mini2.json')
        #self.evalReal = EvaluateQualiSol('./scenarios/s2.json')
        self.objs_map = {}
        self.obj_min_angle_difference = {}
        self.first_real_sol = -1
        for obj in self.immobile_objs:
            if perturb:
                perturb_obj(obj)
            self.objs_map[obj['id']] = obj
            #self.obj_min_angle_difference[obj['id']] = [100,100]
            self.obj_min_angle_difference[obj['id']] = [100, 0]
        if perturb:
            sps(scenario_file, 5, self.immobile_objs)

        self.target_obj_id = self.target_obj['id']

        self.maxd = 3
        self.scenario = None
        self.qualitative_paths = {}
        self.qualitative_paths_actions = {}
        self.zones = []

        self.simulation_counter = 0
        self.round = 0
        quali_sim = qualitative_simulation(scenario_file)
        self.estimated_qualitative_paths = quali_sim.run()

        self.initial_zone = quali_sim.initial_zone
        self.graph = quali_sim.graph
        self.zones = quali_sim.zones
        self.zone_dic = {}
        self.zone_distance = {}
        for x in xrange(self.width):
            for y in xrange(self.height):
                self.zone_dic[(x, y)] = self.__find_zones(x, y)

        for i in xrange(len(self.zones) - 1):
            self.zone_distance[(i, i)] = 0
            for j in xrange(i + 1, len(self.zones)):
                distance = self.zones[i].distance(self.zones[j])
                self.zone_distance[(i, j)] = distance
                self.zone_distance[(j, i)] = distance
                self.zone_distance[(len(self.zones) - 1, len(self.zones) - 1)] = 0

    def __find_zones(self, x, y):
        for i in xrange(len(self.zones)):
            if self.zones[i].contains(Point(x, y)):
                return i
        return -1

    def show_scenario(self):
        """Run the current scenario with rendering enabled so it can be inspected.

        The scene is built with ``self.gravity_dir``, the same gravity that
        ``new_scenario`` uses for the solver's internal simulation, so the
        rendering shows the world that is actually being solved.
        """
        scenario = Scenario_Generator(self.width, self.height, self.immobile_objs, self.mobile_objs,
                                      self.manipulatable_obj, self.target_obj,
                                      gravity_dir=self.gravity_dir, showRender=True)
        scenario.run()

    def new_scenario(self):
        """Build a fresh, non-rendering simulator from the current object state.

        The simulator uses ``self.gravity_dir`` and is stored in ``self.scenario``.
        """
        scene_objects = (self.immobile_objs, self.mobile_objs,
                         self.manipulatable_obj, self.target_obj)
        self.scenario = Scenario_Generator(self.width, self.height, *scene_objects,
                                           gravity_dir=self.gravity_dir, showRender=False)

    def solve_with_rules_classify(self):
        self.round += 1
        all_paths = []
        classification = {}
        essential_contacts = set([])
        quali_paths = []
        sectors_score = []

        path_by_dir = {}
        path_bounces = {}
        path_first_bounces = {}
        possible_impulse_ranges = []
        for path_dir, path, essential_contact, bounce_pos_list in self.estimated_qualitative_paths:
            # print bounces_pos
            # heappush(all_paths, (len(bounces_pos), (path, bounces_pos)))
            # print path_dir
            comp_path = self.make_path_complete(path, self.graph)
            essential_contacts.add(essential_contact)
            if comp_path not in quali_paths:
                quali_paths.append(comp_path)

            if path_dir not in path_by_dir:
                path_by_dir[path_dir] = [comp_path]
                if len(bounce_pos_list) > 0:
                    path_bounces[path_dir] = [bounce_pos_list]
                    path_first_bounces[path_dir] = set([bounce_pos_list[0]])
                else:
                    path_first_bounces[path_dir] = set([])
                    path_bounces[path_dir] = []
            else:
                path_by_dir[path_dir].append(comp_path)
                if len(bounce_pos_list) > 0:
                    path_bounces[path_dir].append(bounce_pos_list)
                    path_first_bounces[path_dir].add(bounce_pos_list[0])

        sort_dirs = []
        for path_dir in path_bounces:
            # calculate average bounce distance
            bounce_pos_list = path_bounces[path_dir]
            average_distance = 0
            for bounce_pos in bounce_pos_list:
                total_distance = 0
                r = 0.6
                for i in xrange(len(bounce_pos) - 1):
                    total_distance += self.zone_distance[(bounce_pos[i], bounce_pos[i + 1])] * pow(1 + r, i)
                average_distance += total_distance
            if len(bounce_pos_list) == 0:
                average_distance = 0
            else:
                average_distance /= len(bounce_pos_list)
            print "path_dir: ", path_dir, "  ", average_distance

            heappush(sort_dirs, (average_distance, path_dir))
        '''
        while all_paths:
            path, bounces_pos = heappop(all_paths)
            print path, " bounces: ", bounces_pos
        '''
        while sort_dirs:
            # for path_dir in path_by_dir:
            distance, path_dir = heappop(sort_dirs)
            print "Test dir range: ", path_dir, " distance ", distance
            # bounces_count = 0

            # for bounces_pos in path_bounces[path_dir]:
            #    print bounces_pos
            # subdivide path_dir into 10 sectors
            # quali_paths = path_by_dir[path_dir]
            divided_sectors = self.divide_dir(path_dir)
            use_less_path = False
            num_iter = 15 #10  # 4
            detected_sols = []
            while num_iter > 0 and not use_less_path:
                num_iter -= 1
                bounces_count = 0

                for sector in divided_sectors:
                    num_samples = 10
                    impulse_range = (IMPULSE_RANGE_X, sector)
                    actions = sample_n_points_from_range(num_samples, impulse_range)
                    # print "test sector: ", sector
                    # print actions
                    for action in actions:
                        path, contacts_info, solved = self.find_qualitative_path_ptlike(action, self.initial_zone)
                        if solved:
                            print "solution in approx sim: ", action

                            real_traj, real_contacts, real_contacts_objs, real_solved = self.evalReal.trialshot_real(
                                action)
                            detected_sols.append(action)
                            if real_solved:
                                #print ' real: ', self.first_real_sol, '  not perturb: ', not self.perturb
                                if self.first_real_sol == -1:
                                    if not self.perturb:
                                        self.first_real_sol = self.simulation_counter
                                        print ' detect first real sol: ', action, self.simulation_counter
                                        exit()
                                    else:
                                        self.first_real_sol = self.evalReal.count

                                print "solution in real evnironment!!!!!", action

                            else:
                                # adjust
                                #self.adjust_approx_sim(real_traj, real_contacts, real_contacts_objs, action)
                                self.adjust_approx_sim_qualitative_path(real_traj, real_contacts, real_contacts_objs, action)

                            '''
                            real_qualitative_path = self.compute_qualitative_path(real_traj, self.initial_zone)
                            print "solution detected: ", action, "  ", self.simulation_counter
                            print "expected qualitative path: ", path, contacts_info
                            print "real qualitative path: ", real_qualitative_path, real_contacts
                            '''
                            continue

                        path = self.make_path_complete(path, self.graph)
                        for first_bounces in path_first_bounces[path_dir]:
                            if first_bounces in path:
                                bounces_count += bounces_count

                        for contact in contacts_info:
                            if contact[0] in essential_contacts:
                                path_str = str(self.make_path_complete(path, self.graph)).strip('[]')
                                # print "path after: " ,path
                                print "perturb_action: ", action
                                max_mu = self.perturb_action(action, path_str, essential_contact)
                                mu_list = np.random.normal(max_mu, 20, 100)

                                for mu in mu_list:
                                    _action = (mu, action[1])
                                    scenario = Scenario_Generator(self.width, self.height, self.immobile_objs,
                                                                  self.mobile_objs, self.manipulatable_obj,
                                                                  self.target_obj, showRender=False)
                                    scenario.apply_impulse_and_run(_action)
                                    self.simulation_counter += 1
                                    # print "sample action: ", _action
                                    if scenario.solved:
                                        print "solution detected in approx ", _action, " ", self.simulation_counter
                                        detected_sols.append(action)
                                        real_traj, real_contacts, real_contacts_objs, real_solved = self.evalReal.trialshot_real(
                                            action)
                                        if real_solved:
                                            if self.first_real_sol == -1:
                                                if not self.perturb:
                                                    self.first_real_sol = self.simulation_counter
                                                    print ' detect first real sol: ', self.simulation_counter
                                                    print exit()
                                                else:
                                                    self.first_real_sol = self.evalReal.count
                                            print "solution in real environment !!! ", action
                                            break
                                        else:
                                            # adjust
                                            #self.adjust_approx_sim(real_traj, real_contacts, real_contacts_objs, action)
                                            self.adjust_approx_sim_qualitative_path(real_traj, real_contacts, real_contacts_objs, action)
                                            break

                    '''
                    if bounces_count == 0:
                    print "exit at: ", 10 - num_iter
                    use_less_path = True
                    '''
            if len(detected_sols) >= 1:
                min_mu = 5000
                max_mu = 0
                min_angle = 7
                max_angle = 0
                for sol in detected_sols:
                    if sol[0] > max_mu:
                        max_mu = sol[0]
                    if sol[0] < min_mu:
                        min_mu = sol[0]
                    if sol[1] > max_angle:
                        max_angle = sol[1]
                    if sol[1] < min_angle:
                        min_angle = sol[1]
                if max_angle - max_angle < 0.1:
                    min_angle -= 0.1
                    max_angle += 0.1
                mu_range = (min_mu - 100, max_mu + 100)
                eval_impulse = [mu_range, (min_angle, max_angle)]
                possible_impulse_ranges.append(eval_impulse)

        for impulse_range in possible_impulse_ranges:
            print 'eval: ', impulse_range
            density, shots = self.evalReal.eval(1000, impulse_range)
            print 'density: ', density, 'first sol: ', shots , ' total shots: ', self.evalReal.count + shots
            if shots != -1 and self.first_real_sol == -1:
                self.first_real_sol = shots

        print 'num of trial shots: ', self.evalReal.count
        print 'simulation steps for finding the first sol: ', self.first_real_sol
        if self.first_real_sol == -1:
            if self.round >= 2:
                exit()
            gc.collect()
            self.simulation_counter = 0
            self.perturb = False
            self.width, self.height, self.immobile_objs, self.mobile_objs, self.manipulatable_obj, self.target_obj = loadScenario(
                scenario_file)
            self.solve_with_rules_classify()


        if self.perturb:
            sps(self.scenario_file, 6, self.immobile_objs)


    def divide_dir(self, path_dir, num_sectors=20):
        """Split the angle range *path_dir* into *num_sectors* equal, contiguous sectors.

        :param path_dir: ``(min_angle, max_angle)`` pair.
        :param num_sectors: number of sectors to produce (default 20).
        :return: list of ``(start, end)`` pairs covering ``path_dir`` in order.
        """
        min_angle, max_angle = path_dir
        # float() avoids Python 2 integer division when the bounds are integers
        width = float(max_angle - min_angle) / num_sectors
        return [(min_angle + i * width, min_angle + (i + 1) * width) for i in range(num_sectors)]

    # return the most similar path (least edit distance)
    def least_distance(self, path, paths):
        """Return the smallest edit distance between the first 10 entries of *path*
        and the first 10 entries of any candidate in *paths*.

        999 is returned when *paths* is empty or no candidate is closer than that.
        """
        head = path[:10]
        candidate_distances = [editdistance.eval(head, candidate[:10]) for candidate in paths]
        return min([999] + candidate_distances)

    def similar_path(self, path_1, paths):
        """Return the candidate in *paths* closest to *path_1*, or None.

        Each candidate ``path_2`` is compared by edit distance against the
        prefix of *path_1* of the same length. The closest candidate is returned
        only when its distance is below 3. Otherwise, including when *paths* is
        empty, None is returned.
        """
        min_d = 999
        result = None
        for path_2 in paths:
            distance = editdistance.eval(path_1[:len(path_2)], path_2)
            if distance < min_d:
                min_d = distance
                result = path_2
        # return the best candidate, not merely the last one iterated
        if min_d < 3:
            return result
        return None

    # maybe not necessary
    def make_path_complete(self, path, graph):
        """Fill gaps in a zone path with intermediate zones from *graph*.

        Consecutive zone pairs are walked in order. When ``path[i]`` is not
        among ``graph.neighbors(path[i + 1])``, the intermediate nodes of
        ``nx.shortest_path(graph, path[i], path[i + 1])`` are inserted. The scan
        stops at the first ``-1`` entry; the last element of *path* is always
        appended at the end. An empty *path* yields ``[]``.
        """
        if not path:
            return []
        new_path = []
        for current, following in zip(path, path[1:]):
            new_path.append(current)
            if following == -1:
                break
            if current != -1 and current not in graph.neighbors(following):
                shortest = nx.shortest_path(graph, source=current, target=following)
                # keep only the missing intermediate nodes along the shortest path
                new_path.extend(shortest[1:-1])
        new_path.append(path[-1])
        return new_path

    def compute_qualitative_path(self, traj, initial_zone):
        pre_zone = initial_zone
        path = [initial_zone]
        for traj_pt in traj:
            x, y = traj_pt
            x = int(x)
            y = int(y)
            if (x, y) not in self.zone_dic:
                occupied_zone = -1
            else:
                occupied_zone = self.zone_dic[(x, y)]
            # if out of scope, still wait to see if it will come back, quite slow
            if occupied_zone == -1 or occupied_zone == pre_zone:
                continue
            path.append(occupied_zone)
            pre_zone = occupied_zone
        return path

    def find_qualitative_path_ptlike(self, action, initial_zone):
        """Run *action* in a fresh approximate simulation and summarise the outcome.

        A new simulator is created by ``new_scenario`` and left in
        ``self.scenario``; the adjustment methods later read its trajectory.
        ``self.simulation_counter`` is incremented.

        :return: ``(path, contacts, solved)``, where:
            - *path* is the zone path of the manipulated object's trajectory,
              starting at *initial_zone* (built by ``compute_qualitative_path``);
            - *contacts* is ``self.scenario.find_contacts_with_mobile_objs()``;
            - *solved* is the simulator's ``solved`` flag.
        """
        self.new_scenario()
        self.scenario.apply_impulse_and_run(action)
        solved = self.scenario.solved
        self.simulation_counter += 1
        traj = self.scenario.find_man_traj()
        b2contacts = self.scenario.find_contacts_with_mobile_objs()
        return self.compute_qualitative_path(traj, initial_zone), b2contacts, solved

    def find_qualitative_path(self, traj):
        pre_zone = self.initial_zone
        path = [self.initial_zone]
        for traj_pt in traj:
            x, y = traj_pt
            x = int(x)
            y = int(y)
            if (x, y) not in self.zone_dic:
                occupied_zone = -1
            else:
                occupied_zone = self.zone_dic[(x, y)]
            # if out of scope, still wait to see if it will come back, quite slow
            if occupied_zone == -1 or occupied_zone == pre_zone:
                continue
            path.append(occupied_zone)
            pre_zone = occupied_zone
        return path

    def bounce_similarity(self, es_bounces, real_bounces):
        """Return how many leading bounces the two bounce sequences share.

        The result is the index of the first position where the sequences
        differ. When one sequence is a prefix of the other (including identical
        sequences), the length of the shorter one is returned. A full match
        therefore always scores at least as high as a partial one; previously it
        scored 0, like a mismatch at the very first bounce.
        """
        for index, (es_bounce, real_bounce) in enumerate(zip(es_bounces, real_bounces)):
            if es_bounce != real_bounce:
                return index
        return min(len(es_bounces), len(real_bounces))

    def perturb_action(self, action, root_path, essential_contact):

        mu, theta = action
        mu_step_size = 50
        num_runs = 100
        max_mu = IMPULSE_RANGE_X[-1]
        current_mu = mu
        # print "perturb_action: ", action
        while num_runs > 0 and mu_step_size > 0.005:
            num_runs -= 1
            action = (current_mu + mu_step_size, theta)
            path, contacts_info, solved = self.find_qualitative_path_ptlike(action, self.initial_zone)
            essential_contact_detected = False
            for contact in contacts_info:
                if contact[0] == essential_contact:
                    essential_contact_detected = True
                if solved:
                    print "solution detected when perturbing: ", action, "  ", self.simulation_counter
                    return current_mu + mu_step_size
            if not essential_contact_detected:
                continue

            path = str(self.make_path_complete(path, self.graph)).strip('[]')
            # print path, " ", action, " ", root_path in path, " max_mu: ", max_mu
            # if path != root_path:
            if root_path not in path:
                # qualitative turning point
                if max_mu > current_mu + mu_step_size:
                    max_mu = current_mu + mu_step_size
                mu_step_size /= 2
            else:

                if current_mu + mu_step_size >= IMPULSE_RANGE_X[-1]:
                    max_mu = IMPULSE_RANGE_X[-1]
                    break
                elif current_mu + mu_step_size > max_mu:
                    max_mu = current_mu + mu_step_size
                current_mu += mu_step_size
                mu_step_size = mu_step_size + mu_step_size

        return max_mu

    def adjust_approx_sim_qualitative_path(self, real_traj, real_bounces, real_contacted_objs, action=None):
        es_traj, es_bounces, es_contacted_objs = self.scenario.find_man_traj_bounce()
        es_num_of_bounces = len(es_contacted_objs)
        real_num_of_bounces = len(real_contacted_objs)
        real_path = self.find_qualitative_path(real_traj)

        real_path = str(real_path).strip('[]')
        real_bounces_str = str(real_bounces).strip('[]')
        #print "es bounce: ", es_bounces, " es_contacted ", es_contacted_objs, " real bounces: ", real_bounces, "  real_contacted: ", real_contacted_objs
        for i in xrange(es_num_of_bounces):
            for j in xrange(real_num_of_bounces):
                #if i == j and es_contacted_objs[i] == real_contacted_objs[j] and es_contacted_objs[i] != 0:
                if i==j and es_contacted_objs[i] == real_contacted_objs[j] and es_contacted_objs[i] != 0 and es_contacted_objs[i] in self.objs_map:
                    #if self.obj_min_angle_difference[es_contacted_objs[i]] != [100, 0]: # already adjusted
                    #    continue
                    es_path = self.find_qualitative_path(es_traj)
                    es_path = str(es_path).strip('[]')
                    es_bounces_str = str(es_bounces).strip('[]')
                    step = 0.001
                    obj = self.objs_map[es_contacted_objs[i]]
                    min_obj_angle = obj['angle']
                    count = 400
                    original_angle = obj['angle']

                    path_difference = editdistance.eval(es_path, real_path)
                    #bounce_difference = editdistance.eval(es_bounces_str, real_bounces_str)
                    bounce_similarity = self.bounce_similarity(es_bounces, real_bounces)

                    min_path_difference = path_difference
                    #min_bounce_difference = bounce_difference
                    max_bounce_similarity = bounce_similarity

                    if min_path_difference == 0:
                        print '!!!!! angle: ', obj['angle'], ' path difference: ', path_difference, ' bounce similarity ', bounce_similarity
                    while (path_difference > 0) and count >= 0:
                        count -= 1
                        self.new_scenario()
                        self.scenario.apply_impulse_and_run(action)
                        _es_traj, _es_bounces, _es_contacted_objs = self.scenario.find_man_traj_bounce()
                        _es_path = self.find_qualitative_path(_es_traj)
                        _es_path = str(_es_path).strip('[]')
                        _es_bounces_str = str(_es_bounces).strip('[]')

                        path_difference = editdistance.eval(_es_path, real_path)
                        bounce_similarity = self.bounce_similarity(_es_bounces, real_bounces)

                        #bounce_difference = editdistance.eval(_es_bounces_str, real_bounces_str)

                        #print _es_bounces_str, ' ||| ', real_bounces_str
                        #print ' path difference ', path_difference, 'bounce difference: ', bounce_difference, '  angle ', obj['angle']
                        #print ' path difference ', path_difference, 'bounce similairty: ', bounce_similarity, '  angle ', \
                        obj['angle']
                        '''
                        if path_difference <= min_path_difference:
                            if path_difference == min_path_difference:
                                if bounce_difference < min_bounce_difference:
                                    min_bounce_difference = bounce_difference
                                    min_path_difference = path_difference
                                    min_obj_angle = obj['angle']
                            else:
                                min_path_difference = path_difference
                                min_bounce_difference = bounce_difference
                                min_obj_angle = obj['angle']
                        '''
                        if path_difference <= min_path_difference:
                            if path_difference == min_path_difference:
                                if bounce_similarity > max_bounce_similarity:
                                    max_bounce_similarity = bounce_similarity
                                    min_path_difference = path_difference
                                    min_obj_angle = obj['angle']
                            else:
                                min_path_difference = path_difference
                                max_bounce_similarity = bounce_similarity
                                min_obj_angle = obj['angle']

                        if count < 200:
                            obj['angle'] += step
                        elif count == 200:
                            obj['angle'] = original_angle + step
                        else:
                            obj['angle'] -= step

                    if min_path_difference == self.obj_min_angle_difference[obj['id']][0] and max_bounce_similarity > \
                            self.obj_min_angle_difference[obj['id']][1]:
                        self.obj_min_angle_difference[obj['id']] = [min_path_difference, max_bounce_similarity]
                        obj['angle'] = min_obj_angle
                    elif min_path_difference < self.obj_min_angle_difference[obj['id']]:
                        self.obj_min_angle_difference[obj['id']] = [min_path_difference, max_bounce_similarity]
                        obj['angle'] = min_obj_angle
                    '''
                    if min_path_difference == self.obj_min_angle_difference[obj['id']][0] and min_bounce_difference < self.obj_min_angle_difference[obj['id']][1]:
                        self.obj_min_angle_difference[obj['id']] = [min_path_difference, min_bounce_difference]
                        obj['angle'] = min_obj_angle
                    elif min_path_difference < self.obj_min_angle_difference[obj['id']]:
                        self.obj_min_angle_difference[obj['id']] = [min_path_difference, min_bounce_difference]
                        obj['angle'] = min_obj_angle
                    '''
                    print 'obj_id: ', obj['id'], 'adjusted angle: ', min_obj_angle, 'path_d: ', min_path_difference, 'bounce_d', max_bounce_similarity, ' final angle: ', obj['angle'], ' final difference: ', self.obj_min_angle_difference[obj['id']]
                else:
                    continue
        return


    def adjust_approx_sim(self, real_traj, real_bounces, real_contacted_objs, action=None):
        es_traj, es_bounces, es_contacted_objs = self.scenario.find_man_traj_bounce()
        es_num_of_bounces = len(es_contacted_objs)
        real_num_of_bounces = len(real_contacted_objs)
        for i in xrange(es_num_of_bounces):
            for j in xrange(real_num_of_bounces):
                if i == j and es_contacted_objs[i] == real_contacted_objs[j] and es_contacted_objs[i] != 0:
                    es_bounce_pt_inx = es_bounces[i]
                    real_bounce_pt_inx = real_bounces[j]
                    #if len(es_traj) - 1 < es_bounce_pt_inx + 2 or real_bounce_pt_inx + 2 > len(real_traj) - 1:
                    #    break
                    es_bounces_pts = [es_traj[es_bounce_pt_inx - 1], es_traj[es_bounce_pt_inx],
                                      es_traj[es_bounce_pt_inx + 1]]
                    real_bounces_pts = [real_traj[real_bounce_pt_inx - 1], real_traj[real_bounce_pt_inx],
                                        real_traj[real_bounce_pt_inx + 1]]
                    angle_difference, translation = Adjust.find_spatial_difference(es_bounces_pts, real_bounces_pts)

                    # step = sigma_angle
                    step = 0.001
                    if angle_difference > 0:
                        neg_flag = False
                    else:
                        neg_flag = True
                    obj = self.objs_map[es_contacted_objs[i]]
                    #obj['position'] = (obj['position'][0] + translation[0], obj['position'][1] + translation[1])
                    # last_obj_angle = obj['angle']
                    current_obj_angle = obj['angle']
                    last_obj_angle = obj['angle']
                    last_angle_difference = angle_difference
                    min_angle_difference = 100
                    min_obj_angle = obj['angle']
                    count = 400
                    original_angle = obj['angle']
                    reversed = False
                    if abs(angle_difference) <= 0.001:
                        print ' !!!!!! obj angle', min_obj_angle, ' angfle difference ', angle_difference
                    while (abs(angle_difference) > 0.001) and count >= 0:
                        count -= 1
                        # update
                        #print 'angle differ: ', angle_difference, ' obj angle: ', obj['angle']
                        #print "translation: ", translation
                        #print 'current angle difference: ', angle_difference, " last angle difference", last_angle_difference, "  ", obj['angle']
                        # add a map record the last angle difference recorded for each adjusted object
                        '''
                        if not reversed:
                            obj['angle'] = last_obj_angle + step
                            last_obj_angle = obj['angle']
                            last_angle_difference = angle_difference
                        else:
                            obj['angle'] = last_obj_angle - step
                            last_obj_angle = obj['angle']
                            last_angle_difference = angle_difference

                        '''
                        '''
                        if angle_difference < 0:
                            obj['angle'] = last_obj_angle + step
                            last_obj_angle = obj['angle']
                            last_angle_difference = angle_difference

                        else:
                            obj['angle'] = last_obj_angle - step
                            last_obj_angle = obj['angle']
                            last_angle_difference = angle_difference
                        '''
                        '''
                        print 'angle difference: ', angle_difference, " current: ", current_obj_angle, "  last: ", last_obj_angle, "  neg flag: ", neg_flag
                        if angle_difference < 0:
                            if not neg_flag:
                                obj['angle'] = (last_obj_angle + current_obj_angle)/2
                                last_obj_angle = current_obj_angle
                                current_obj_angle = obj['angle']
                                neg_flag = True
                            else:
                                obj['angle'] += step
                                last_obj_angle = current_obj_angle
                                current_obj_angle = obj['angle']
                        else:
                            if neg_flag:
                                obj['angle'] = (last_obj_angle + current_obj_angle) / 2
                                last_obj_angle = current_obj_angle
                                current_obj_angle = obj['angle']
                                neg_flag = False
                            else:
                                obj['angle'] -= step
                                last_obj_angle = current_obj_angle
                                current_obj_angle = obj['angle']
                        '''

                        self.new_scenario()
                        self.scenario.apply_impulse_and_run(action)
                        _es_traj, _es_bounces, _es_contacted_objs = self.scenario.find_man_traj_bounce()
                        _es_bounces_pts = [_es_traj[es_bounce_pt_inx - 1], _es_traj[es_bounce_pt_inx],
                                           _es_traj[es_bounce_pt_inx + 1]]

                        angle_difference, translation = Adjust.find_spatial_difference(_es_bounces_pts,
                                                                                       real_bounces_pts)

                        if abs(angle_difference) < min_angle_difference:
                            min_angle_difference = abs(angle_difference)
                            min_obj_angle = obj['angle']

                        if count < 200:
                            obj['angle'] += step
                        elif count == 200:
                            obj['angle'] = original_angle + step
                        else:
                            obj['angle'] -= step
                        '''
                        if abs(angle_difference) < min_angle_difference:
                            min_angle_difference = abs(angle_difference)
                            min_obj_angle = obj['angle']
                            #step *= 2
                        else:
                            reversed = not reversed
                            step /= 2
                        '''
                            #if abs(translation[0]) < sigma_pos and abs(translation[1]) < sigma_pos:
                            #    obj['position'] = (obj['position'][0] + translation[0], obj['position'][1] + translation[1])


                    if min_angle_difference < self.obj_min_angle_difference[obj['id']]:
                        obj['angle'] = min_obj_angle
                        self.obj_min_angle_difference[obj['id']] = min_angle_difference
                    print 'adjusted angle: ', min_obj_angle, 'min_angle_difference: ', min_angle_difference, ' final angle: ', obj['angle'], ' final difference: ', self.obj_min_angle_difference[obj['id']]
                    break
                else:
                    continue
        return
    def solve_with_rules_classify(self):
        self.round += 1
        all_paths = []
        classification = {}
        essential_contacts = set([])
        quali_paths = []
        sectors_score = []

        path_by_dir = {}
        path_bounces = {}
        path_first_bounces = {}
        possible_impulse_ranges = []
        for path_dir, path, essential_contact, bounce_pos_list in self.estimated_qualitative_paths:
            # print bounces_pos
            # heappush(all_paths, (len(bounces_pos), (path, bounces_pos)))
            # print path_dir
            comp_path = self.make_path_complete(path, self.graph)
            essential_contacts.add(essential_contact)
            if comp_path not in quali_paths:
                quali_paths.append(comp_path)

            if path_dir not in path_by_dir:
                path_by_dir[path_dir] = [comp_path]
                if len(bounce_pos_list) > 0:
                    path_bounces[path_dir] = [bounce_pos_list]
                    path_first_bounces[path_dir] = set([bounce_pos_list[0]])
                else:
                    path_first_bounces[path_dir] = set([])
                    path_bounces[path_dir] = []
            else:
                path_by_dir[path_dir].append(comp_path)
                if len(bounce_pos_list) > 0:
                    path_bounces[path_dir].append(bounce_pos_list)
                    path_first_bounces[path_dir].add(bounce_pos_list[0])

        sort_dirs = []
        for path_dir in path_bounces:
            # calculate average bounce distance
            bounce_pos_list = path_bounces[path_dir]
            average_distance = 0
            for bounce_pos in bounce_pos_list:
                total_distance = 0
                r = 0.6
                for i in xrange(len(bounce_pos) - 1):
                    total_distance += self.zone_distance[(bounce_pos[i], bounce_pos[i + 1])] * pow(1 + r, i)
                average_distance += total_distance
            if len(bounce_pos_list) == 0:
                average_distance = 0
            else:
                average_distance /= len(bounce_pos_list)
            print "path_dir: ", path_dir, "  ", average_distance

            heappush(sort_dirs, (average_distance, path_dir))
        '''
        while all_paths:
            path, bounces_pos = heappop(all_paths)
            print path, " bounces: ", bounces_pos
        '''
        while sort_dirs:
            # for path_dir in path_by_dir:
            distance, path_dir = heappop(sort_dirs)
            print "Test dir range: ", path_dir, " distance ", distance
            # bounces_count = 0

            # for bounces_pos in path_bounces[path_dir]:
            #    print bounces_pos
            # subdivide path_dir into 10 sectors
            # quali_paths = path_by_dir[path_dir]
            divided_sectors = self.divide_dir(path_dir)
            use_less_path = False
            num_iter = 15 #10  # 4
            detected_sols = []
            while num_iter > 0 and not use_less_path:
                num_iter -= 1
                bounces_count = 0

                for sector in divided_sectors:
                    num_samples = 10
                    impulse_range = (IMPULSE_RANGE_X, sector)
                    actions = sample_n_points_from_range(num_samples, impulse_range)
                    # print "test sector: ", sector
                    # print actions
                    for action in actions:
                        path, contacts_info, solved = self.find_qualitative_path_ptlike(action, self.initial_zone)
                        if solved:
                            print "solution in approx sim: ", action

                            real_traj, real_contacts, real_contacts_objs, real_solved = self.evalReal.trialshot_real(
                                action)
                            detected_sols.append(action)
                            if real_solved:
                                #print ' real: ', self.first_real_sol, '  not perturb: ', not self.perturb
                                if self.first_real_sol == -1:
                                    if not self.perturb:
                                        self.first_real_sol = self.simulation_counter
                                        print ' detect first real sol: ', action, self.simulation_counter
                                        exit()
                                    else:
                                        self.first_real_sol = self.evalReal.count

                                print "solution in real evnironment!!!!!", action

                            else:
                                # adjust
                                #self.adjust_approx_sim(real_traj, real_contacts, real_contacts_objs, action)
                                self.adjust_approx_sim_qualitative_path(real_traj, real_contacts, real_contacts_objs, action)

                            '''
                            real_qualitative_path = self.compute_qualitative_path(real_traj, self.initial_zone)
                            print "solution detected: ", action, "  ", self.simulation_counter
                            print "expected qualitative path: ", path, contacts_info
                            print "real qualitative path: ", real_qualitative_path, real_contacts
                            '''
                            continue

                        path = self.make_path_complete(path, self.graph)
                        for first_bounces in path_first_bounces[path_dir]:
                            if first_bounces in path:
                                bounces_count += bounces_count

                        for contact in contacts_info:
                            if contact[0] in essential_contacts:
                                path_str = str(self.make_path_complete(path, self.graph)).strip('[]')
                                # print "path after: " ,path
                                print "perturb_action: ", action
                                max_mu = self.perturb_action(action, path_str, essential_contact)
                                mu_list = np.random.normal(max_mu, 20, 100)

                                for mu in mu_list:
                                    _action = (mu, action[1])
                                    scenario = Scenario_Generator(self.width, self.height, self.immobile_objs,
                                                                  self.mobile_objs, self.manipulatable_obj,
                                                                  self.target_obj, showRender=False)
                                    scenario.apply_impulse_and_run(_action)
                                    self.simulation_counter += 1
                                    # print "sample action: ", _action
                                    if scenario.solved:
                                        print "solution detected in approx ", _action, " ", self.simulation_counter
                                        detected_sols.append(action)
                                        real_traj, real_contacts, real_contacts_objs, real_solved = self.evalReal.trialshot_real(
                                            action)
                                        if real_solved:
                                            if self.first_real_sol == -1:
                                                if not self.perturb:
                                                    self.first_real_sol = self.simulation_counter
                                                    print ' detect first real sol: ', self.simulation_counter
                                                    print exit()
                                                else:
                                                    self.first_real_sol = self.evalReal.count
                                            print "solution in real environment !!! ", action
                                            break
                                        else:
                                            # adjust
                                            #self.adjust_approx_sim(real_traj, real_contacts, real_contacts_objs, action)
                                            self.adjust_approx_sim_qualitative_path(real_traj, real_contacts, real_contacts_objs, action)
                                            break

                    '''
                    if bounces_count == 0:
                    print "exit at: ", 10 - num_iter
                    use_less_path = True
                    '''
            if len(detected_sols) >= 1:
                min_mu = 5000
                max_mu = 0
                min_angle = 7
                max_angle = 0
                for sol in detected_sols:
                    if sol[0] > max_mu:
                        max_mu = sol[0]
                    if sol[0] < min_mu:
                        min_mu = sol[0]
                    if sol[1] > max_angle:
                        max_angle = sol[1]
                    if sol[1] < min_angle:
                        min_angle = sol[1]
                if max_angle - max_angle < 0.1:
                    min_angle -= 0.1
                    max_angle += 0.1
                mu_range = (min_mu - 100, max_mu + 100)
                eval_impulse = [mu_range, (min_angle, max_angle)]
                possible_impulse_ranges.append(eval_impulse)

        for impulse_range in possible_impulse_ranges:
            print 'eval: ', impulse_range
            density, shots = self.evalReal.eval(1000, impulse_range)
            print 'density: ', density, 'first sol: ', shots , ' total shots: ', self.evalReal.count + shots
            if shots != -1 and self.first_real_sol == -1:
                self.first_real_sol = shots

        print 'num of trial shots: ', self.evalReal.count
        print 'simulation steps for finding the first sol: ', self.first_real_sol
        if self.first_real_sol == -1:
            if self.round >= 2:
                exit()
            gc.collect()
            self.simulation_counter = 0
            self.perturb = False
            self.width, self.height, self.immobile_objs, self.mobile_objs, self.manipulatable_obj, self.target_obj = loadScenario(
                scenario_file)
            self.solve_with_rules_classify()


        if self.perturb:
            sps(self.scenario_file, 6, self.immobile_objs)
    def show_scenario(self):
        """Build the current scenario with rendering enabled and run it.

        The solver's ``gravity_dir`` is passed so the rendered run uses the same
        gravity as the approximate simulator built by ``new_scenario``.
        """
        scenario = Scenario_Generator(self.width, self.height, self.immobile_objs, self.mobile_objs,
                                      self.manipulatable_obj, self.target_obj,
                                      gravity_dir=self.gravity_dir, showRender=True)
        scenario.run()
 def new_scenario(self):
     self.scenario = Scenario_Generator(self.width, self.height, self.immobile_objs, self.mobile_objs,
                                        self.manipulatable_obj, self.target_obj, gravity_dir=self.gravity_dir, showRender=False)
from scenario_generator_adjust import Scenario_Generator as SG
import Adjust
from game_object import GameObject as GOBJ
# from path_finding import createGraph
from triangulation import triangulate_advanced
import numpy as np
import networkx as nx
from triangle_utils import triangle_center
from triangle_relation import printRels
from scenario_reader import loadScenario

# scenario_file = './scenarios/s3.json'
# Debug script: load one scenario, triangulate it with triangulate_advanced and
# plot the first triangulation stored in `tri`.
# Scenario name (without extension) under ./scenarios/ to analyse.
file_name = 's2p2'
scenario_width, scenario_height, immobile_objs, mobile_objs, manipulatable_obj, target_obj = loadScenario(
    './scenarios/' + file_name + '.json')
# Simulator created with rendering disabled; in the visible code it is only
# used to obtain the game objects.
scenario = SG(scenario_width, scenario_height, immobile_objs, mobile_objs, manipulatable_obj, target_obj,
              showRender=False)
## get objects
game_objects = scenario.getGameObjects()
# Passed in empty; presumably triangulate_advanced appends the raw triangulation
# data to it (tri[0] is plotted below) -- TODO confirm.
tri = []
graph, edges_index, edges_dirs, edges_by_tri, vertices_of_edges, edges_surface_dic, tri_by_surface, surface_rel_dic, tri_neighbor_by_edge, edges_by_object_id, objects_id_by_edge, edges_length, zones = triangulate_advanced(
    game_objects, scenario_width, scenario_height, tri)

# NOTE(review): neither `plot` nor `plt` is imported in the import block above
# (presumably triangle.plot and matplotlib.pyplot); unless they are brought
# into scope elsewhere this line raises NameError -- verify.
plot.plot(plt.axes(), **(tri[0]))

################### Arrange graph according to the position of triangles #########

# triangles_by_vertices = tri["vertices"][tri["triangles"]]
# Maps each zone index to the (x, y) of that zone's centroid (filled in below).
pos = {}
for index, triangle in enumerate(zones):
    pos[index] = (triangle.centroid.x, triangle.centroid.y)