Example #1
0
 def cb_controls(self, map_message):
     """Keep the control values of ``map_message`` that belong to
     ``self.str_controls`` and hand them to the tracker.

     Args:
         map_message: Message exposing parallel ``symbol`` (names as strings)
                      and ``value`` sequences.
     """
     accepted = {}
     for name, value in zip(map_message.symbol, map_message.value):
         # Ignore any symbol that is not one of our configured controls
         if name in self.str_controls:
             accepted[gm.Symbol(name)] = value
     self.last_control = accepted
     self.tracker.process_control(self.last_control)
Example #2
0
    def cb_external_js(self, value_msg):
        """Merge the symbol/value pairs carried by ``value_msg`` into ``self._state``.

        Names in ``value_msg.symbol`` are converted with ``gm.Symbol``; the
        merge itself happens while ``self._state_lock`` is held.
        """
        incoming = dict(zip(map(gm.Symbol, value_msg.symbol), value_msg.value))

        with self._state_lock:
            self._state.update(incoming)
Example #3
0
    def __init__(self, km, controlled_symbols, resting_pose, camera_path=None):
        """Build a QP that pulls symbols toward a resting pose and optionally
        sweeps a camera's gaze across a moving point of interest.

        Args:
            km: Kinematic model providing ``get_data`` and
                ``get_constraints_by_symbols``.
            controlled_symbols: Not used. The value is replaced below by the
                derivative symbols of the joint symbols found in the goals.
            resting_pose (dict or None): Maps symbols to their resting values;
                one 'tuck <symbol>' goal is created per entry.
            camera_path (optional): Path of a camera object in ``km``. When
                given, a 'sweeping gaze' goal and the symbol 'poi' are added.
        """
        goals = {}
        if resting_pose is not None:
            for sym, rest in resting_pose.items():
                # Lower and upper bound both equal the remaining distance to rest
                goals[f'tuck {sym}'] = SC(rest - sym, rest - sym, 1, sym)

        self.use_camera = camera_path is not None
        if self.use_camera:
            # The point of interest slides along y, parameterized by 'poi'
            self._poi_pos = gm.Symbol('poi')
            poi = gm.point3(1.5, 0.5, 0.0) + gm.vector3(
                0, self._poi_pos * 2.0, 0)

            camera = km.get_data(camera_path)
            to_poi = poi - gm.pos_of(camera.pose)
            # 1 - cos(angle between the camera x axis and the direction to the
            # POI); presumably 0 when looking straight at it if x_of is a unit vector
            gaze_error = 1 - gm.dot_product(gm.x_of(camera.pose),
                                            to_poi) / gm.norm(to_poi)
            goals['sweeping gaze'] = SC(-gaze_error * 5, -gaze_error * 5, 1,
                                        gaze_error)

        symbols = set().union(*(gm.free_symbols(goal.expr)
                                for goal in goals.values()))

        joints = {
            sym for sym in symbols
            if gm.get_symbol_type(sym) != gm.TYPE_UNKNOWN
        }
        ctrl_symbols = {gm.DiffSymbol(sym) for sym in joints}

        hard = km.get_constraints_by_symbols(symbols.union(ctrl_symbols))

        cvs, hard = generate_controlled_values(hard, ctrl_symbols)
        cvs = depth_weight_controlled_values(km, cvs)

        self.qp = TQPB(hard, goals, cvs)
        self._start = Time.now()
    def __init__(self,
                 km,
                 lead_goal_constraints,
                 follower_goal_constraints,
                 t_leader=TQPB,
                 t_follower=TQPB,
                 f_gen_lead_cvs=None,
                 f_gen_follower_cvs=None,
                 visualizer=None,
                 controls_blacklist=None,
                 transition_overrides=None):
        """Set up a leader QP and a follower QP plus their transition functions.

        The leader controls the derivative symbols of its goal constraints. The
        follower controls only those derivative symbols of its goals whose
        integrated symbol does not already occur in the leader's goals.

        Args:
            km: Kinematic model providing constraints and geometry.
            lead_goal_constraints (dict): Named goal constraints of the leader.
            follower_goal_constraints (dict): Named goal constraints of the follower.
            t_leader (type): QP builder class for the leader. Subclasses of
                ``GQPB`` additionally receive the active geometry and ``visualizer``.
            t_follower (type): QP builder class for the follower; same rules.
            f_gen_lead_cvs (callable, optional): Called as
                ``f(km, constraints, controlled_symbols)`` and expected to return
                ``(controlled_values, constraints)``. Defaults to
                ``self.gen_controlled_values``.
            f_gen_follower_cvs (callable, optional): Same, for the follower.
            visualizer (optional): Forwarded to ``GQPB``-based builders.
            controls_blacklist (collection, optional): Derivative symbols that
                must never be controlled. Defaults to an empty set.
            transition_overrides (optional): Forwarded to
                ``generate_transition_function``.
        """
        # A fresh set per call instead of a shared mutable default argument.
        if controls_blacklist is None:
            controls_blacklist = set()

        lead_symbols = union(
            gm.free_symbols(c.expr) for c in lead_goal_constraints.values())
        follower_symbols = union(
            gm.free_symbols(c.expr)
            for c in follower_goal_constraints.values())
        self.lead_symbols = lead_symbols
        self.follower_symbols = follower_symbols

        self.lead_controlled_symbols = {
            s
            for s in union(
                gm.get_diff_symbols(c.expr)
                for c in lead_goal_constraints.values())
            if s not in controls_blacklist
        }
        # Only control the follower symbols whose integrated symbol is not
        # already part of the leader's problem.
        self.follower_controlled_symbols = {
            s
            for s in union(
                gm.get_diff_symbols(c.expr)
                for c in follower_goal_constraints.values())
            if s not in controls_blacklist
            and gm.IntSymbol(s) not in lead_symbols
        }

        if f_gen_lead_cvs is None:
            f_gen_lead_cvs = self.gen_controlled_values
        lead_cvs, lead_constraints = f_gen_lead_cvs(
            km,
            km.get_constraints_by_symbols(
                self.lead_controlled_symbols.union(
                    {gm.IntSymbol(s) for s in self.lead_controlled_symbols})),
            self.lead_controlled_symbols)

        if f_gen_follower_cvs is None:
            f_gen_follower_cvs = self.gen_controlled_values
        follower_cvs, follower_constraints = f_gen_follower_cvs(
            km,
            km.get_constraints_by_symbols(
                self.follower_controlled_symbols.union(
                    {gm.IntSymbol(s)
                     for s in self.follower_controlled_symbols})),
            self.follower_controlled_symbols)

        if issubclass(t_leader, GQPB):
            lead_world = km.get_active_geometry(lead_symbols)
            self.lead_qp = t_leader(lead_world,
                                    lead_constraints,
                                    lead_goal_constraints,
                                    lead_cvs,
                                    visualizer=visualizer)
        else:
            self.lead_qp = t_leader(lead_constraints, lead_goal_constraints,
                                    lead_cvs)

        self.sym_dt = gm.Symbol('dT')
        (self.lead_o_symbols,
         self.lead_t_function,
         self.lead_o_controls) = generate_transition_function(
             self.sym_dt, lead_symbols, transition_overrides)

        (self.follower_o_symbols,
         self.follower_t_function,
         self.follower_o_controls) = generate_transition_function(
             self.sym_dt,
             {gm.IntSymbol(s) for s in self.follower_controlled_symbols},
             transition_overrides)

        # One row per follower control: the lower and upper bound expressions
        # of every constraint affecting it, in constraint order.
        self.follower_o_bounds = list(self.follower_controlled_symbols)
        follower_ctrl_bounds = [
            [bound
             for c in km.get_constraints_by_symbols({s}).values()
             for bound in (c.lower, c.upper)]
            for s in self.follower_o_bounds
        ]

        # Pad shorter rows with the constant 1e3 so all rows have equal length.
        # NOTE(review): raises ValueError if there are no follower controls.
        max_bounds = max(len(row) for row in follower_ctrl_bounds)

        for s, row in zip(self.follower_o_bounds, follower_ctrl_bounds):
            row.extend([1e3] * (max_bounds - len(row)))
            print(f'{s}: {row}')

        follower_ctrl_bounds = gm.Matrix(follower_ctrl_bounds).T
        self.follower_ctrl_bounds_params = list(
            gm.free_symbols(follower_ctrl_bounds))
        self.follower_ctrl_bounds_f = gm.speed_up(
            follower_ctrl_bounds, self.follower_ctrl_bounds_params)

        self.follower_delta_map = {
            gm.IntSymbol(s): s
            for s in self.follower_controlled_symbols
        }

        if issubclass(t_follower, GQPB):
            follower_world = km.get_active_geometry(follower_symbols)
            self.follower_qp = t_follower(follower_world,
                                          follower_constraints,
                                          follower_goal_constraints,
                                          follower_cvs,
                                          visualizer=visualizer)
        else:
            self.follower_qp = t_follower(follower_constraints,
                                          follower_goal_constraints,
                                          follower_cvs)
                             world_frame='world',
                             wheel_radius=0.12 * 0.5,
                             wheel_distance=0.3748,
                             wheel_vel_limit=17.4)
        else:
            insert_omni_base(km, robot_path, robot_urdf.get_root(), 'world')
        km.clean_structure()
        km.dispatch_events()

    robot   = km.get_data(robot_path)
    nobilia = km.get_data(model_name)

    handle  = nobilia.links[rospy.get_param('~handle', 'handle')]
    integration_rules = None

    sym_dt = gm.Symbol('dT')

    if robot_name == 'pr2':
        joint_symbols, \
        robot_controlled_symbols, \
        start_pose, \
        eef, \
        blacklist = pr2_setup(km, robot_path)
    elif robot_name == 'fetch':
        joint_symbols, \
        robot_controlled_symbols, \
        start_pose, \
        eef, \
        blacklist = generic_setup(km, robot_path, rospy.get_param('~eef', 'gripper_link'))
        if rospy.get_param('~use_base', False):
            base_joint = robot.joints['to_world']
def create_nobilia_shelf(km,
                         prefix,
                         origin_pose=gm.eye(4),
                         parent_path=Path('world')):
    """Create a symbolic model of a Nobilia shelf with a folding two-panel flap.

    The model is inserted into ``km`` under ``prefix`` through a sequence of
    ``km.apply_operation`` calls:
    - a ``MarkedArticulatedObject`` container;
    - the links ``body``, ``panel_top``, ``panel_bottom`` and ``handle``;
    - a ``RevoluteJoint`` at ``joints/hinge``, driven by the symbol
      ``gm.Position(prefix + ('door', ))``;
    - three marker ``Frame``s under ``markers/``.

    The bottom panel's pose is derived from the linkage sketched below, so the
    whole flap depends on the single hinge position symbol.

    Args:
        km: Kinematic model the operations are applied to.
        prefix (Path): Path prefix of the created object.
        origin_pose: 4x4 pose of the shelf body relative to ``parent_path``.
        parent_path (Path): Path of the body's parent frame.

    Returns:
        NobiliaDebug: Debug data built from four parts:
        - a list of intermediate pose expressions;
        - a list of (point, vector) pairs describing the linkage;
        - a dict of named scalar expressions (angles, lengths, consistency checks);
        - a dict of values for the symbols used by the commented-out
          parametric variants of points c and d and of length l.

    NOTE(review): the defaults of ``origin_pose`` and ``parent_path`` are
    evaluated once at definition time and shared between calls. That is
    harmless only if nothing downstream mutates them — confirm.
    """
    km.apply_operation(
        f'create {prefix}',
        ExecFunction(prefix, MarkedArticulatedObject, str(prefix)))

    # Outer dimensions of the shelf body (units presumably meters — confirm)
    shelf_height = 0.72
    shelf_width = 0.6
    shelf_body_depth = 0.35

    # Thickness of the body's boards
    wall_width = 0.016

    l_prefix = prefix + ('links', )
    # Body boards. Their placement is symmetric about the body frame's origin,
    # so that origin sits at the center of the body.
    geom_body_wall_l = Box(
        l_prefix + ('body', ),
        gm.translation3(0, 0.5 * (shelf_width - wall_width), 0),
        gm.vector3(shelf_body_depth, wall_width, shelf_height))
    geom_body_wall_r = Box(
        l_prefix + ('body', ),
        gm.translation3(0, -0.5 * (shelf_width - wall_width), 0),
        gm.vector3(shelf_body_depth, wall_width, shelf_height))

    geom_body_ceiling = Box(
        l_prefix + ('body', ),
        gm.translation3(0, 0, 0.5 * (shelf_height - wall_width)),
        gm.vector3(shelf_body_depth, shelf_width - wall_width, wall_width))
    geom_body_floor = Box(
        l_prefix + ('body', ),
        gm.translation3(0, 0, -0.5 * (shelf_height - wall_width)),
        gm.vector3(shelf_body_depth, shelf_width - wall_width, wall_width))

    # Two inner shelf boards
    geom_body_shelf_1 = Box(
        l_prefix + ('body', ),
        gm.translation3(0.02, 0, -0.2 * (shelf_height - wall_width)),
        gm.vector3(shelf_body_depth - 0.04, shelf_width - wall_width,
                   wall_width))

    geom_body_shelf_2 = Box(
        l_prefix + ('body', ),
        gm.translation3(0.02, 0, 0.2 * (shelf_height - wall_width)),
        gm.vector3(shelf_body_depth - 0.04, shelf_width - wall_width,
                   wall_width))

    # Thin back board at the +x side of the body
    geom_body_back = Box(
        l_prefix + ('body', ),
        gm.translation3(0.5 * (shelf_body_depth - 0.005), 0, 0),
        gm.vector3(0.005, shelf_width - 2 * wall_width,
                   shelf_height - 2 * wall_width))

    shelf_geom = [
        geom_body_wall_l, geom_body_wall_r, geom_body_ceiling, geom_body_floor,
        geom_body_back, geom_body_shelf_1, geom_body_shelf_2
    ]

    # The same boxes serve as visual and collision geometry
    rb_body = RigidBody(parent_path,
                        origin_pose,
                        geometry=dict(enumerate(shelf_geom)),
                        collision=dict(enumerate(shelf_geom)))

    # The two flap panels
    geom_panel_top = Box(l_prefix + ('panel_top', ), gm.eye(4),
                         gm.vector3(0.357, 0.595, wall_width))
    geom_panel_bottom = Box(l_prefix + ('panel_bottom', ), gm.eye(4),
                            gm.vector3(0.357, 0.595, wall_width))

    handle_width = 0.16
    handle_depth = 0.05
    handle_diameter = 0.012

    # Handle: two standoffs (l/r) and the bar connecting their far ends
    geom_handle_r = Box(
        l_prefix + ('handle', ),
        gm.translation3(0.5 * handle_depth,
                        0.5 * (handle_width - handle_diameter), 0),
        gm.vector3(handle_depth, handle_diameter, handle_diameter))
    geom_handle_l = Box(
        l_prefix + ('handle', ),
        gm.translation3(0.5 * handle_depth,
                        -0.5 * (handle_width - handle_diameter), 0),
        gm.vector3(handle_depth, handle_diameter, handle_diameter))
    geom_handle_bar = Box(
        l_prefix + ('handle', ),
        gm.translation3(handle_depth - 0.5 * handle_diameter, 0, 0),
        gm.vector3(handle_diameter, handle_width - handle_diameter,
                   handle_diameter))

    handle_geom = [geom_handle_l, geom_handle_r, geom_handle_bar]

    # Sketch of mechanism
    #
    #           T ---- a
    #         ----      \  Z
    #       b ..... V    \
    #       |      ...... d
    #    B  |       ------
    #       c ------
    #                L
    #
    # Diagonal V is virtual
    #
    #
    # Angles:
    #   a -> alpha (given)
    #   b -> gamma_1 + gamma_2 = gamma
    #   c -> don't care
    #   d -> delta_1 + delta_2 = delta
    #

    # The single symbol driving the whole flap (used as the hinge's position)
    opening_position = gm.Position(prefix + ('door', ))

    # Calibration results
    #
    # Solution top hinge: cost = 0.03709980624159568 [ 0.08762252 -0.01433833  0.2858676   0.00871125]
    # Solution bottom hinge: cost = 0.025004236048128934 [ 0.1072496  -0.01232362  0.27271013  0.00489996]
    #
    # Each solution's first two values are the (x, z) offset of a hinge from
    # the preceding marker. Its last two values are the (x, z) offset of the
    # next marker from that hinge. Some of them get additional constant offsets
    # added below.

    # Added 180 deg rotation due to -x being the forward facing side in this model
    # NOTE(review): the transforms directly below are pure translations; the
    # rotation this comment refers to is applied in body_marker_in_body.
    top_hinge_in_body_marker = gm.translation3(0.08762252 - 0.015, 0,
                                               -0.01433833)
    top_panel_marker_in_top_hinge = gm.translation3(0.2858676 - 0.003,
                                                    -wall_width + 0.0025,
                                                    0.00871125 - 0.003)
    # NOTE: "maker" is presumably a typo for "marker"
    front_hinge_in_top_panel_maker = gm.translation3(0.1072496 - 0.02, 0,
                                                     -0.01232362 + 0.007)
    bottom_panel_marker_in_front_hinge = gm.translation3(
        0.27271013, 0, 0.00489996)

    # Top hinge - Data taken from observation
    body_marker_in_body = gm.dot(
        gm.rotation3_axis_angle(gm.vector3(0, 0, 1), math.pi),
        gm.translation3(0.5 * shelf_body_depth - 0.062,
                        -0.5 * shelf_width + 0.078, 0.5 * shelf_height))
    # Panel markers sit near a corner on the panel's upper surface
    top_panel_marker_in_top_panel = gm.translation3(
        geom_panel_top.scale[0] * 0.5 - 0.062,
        -geom_panel_top.scale[1] * 0.5 + 0.062, geom_panel_top.scale[2] * 0.5)
    bottom_panel_marker_in_bottom_panel = gm.translation3(
        geom_panel_bottom.scale[0] * 0.5 - 0.062,
        -geom_panel_bottom.scale[1] * 0.5 + 0.062,
        geom_panel_bottom.scale[2] * 0.5)

    # Chain the marker-relative calibration into link-relative transforms
    top_hinge_in_body = gm.dot(body_marker_in_body, top_hinge_in_body_marker)
    top_panel_in_top_hinge = gm.dot(
        top_panel_marker_in_top_hinge,
        gm.inverse_frame(top_panel_marker_in_top_panel))
    front_hinge_in_top_panel = gm.dot(top_panel_marker_in_top_panel,
                                      front_hinge_in_top_panel_maker)
    bottom_panel_in_front_hinge = gm.dot(
        bottom_panel_marker_in_front_hinge,
        gm.inverse_frame(bottom_panel_marker_in_bottom_panel))

    # Point a in body reference frame
    # (diag(1, 0, 1, 1) zeroes the y component: the linkage is treated in x-z)
    point_a = gm.dot(gm.diag(1, 0, 1, 1), gm.pos_of(top_hinge_in_body))
    point_d = gm.point3(-shelf_body_depth * 0.5 + 0.09, 0,
                        shelf_height * 0.5 - 0.192)
    # point_d  = gm.point3(-shelf_body_depth * 0.5 + gm.Symbol('point_d_x'), 0, shelf_height * 0.5 - gm.Symbol('point_d_z'))
    # Zero alpha along the vertical axis
    # NOTE(review): gm.dot is called with a single argument here; presumably it
    # returns that argument unchanged — confirm.
    vec_a_to_d = gm.dot(point_d - point_a)
    # alpha only appears in the returned debug values
    alpha = gm.atan2(vec_a_to_d[0], -vec_a_to_d[2]) + opening_position

    top_panel_in_body = gm.dot(
        top_hinge_in_body,  # Translation hinge to body frame
        gm.rotation3_axis_angle(gm.vector3(0, 1, 0), -opening_position +
                                0.5 * math.pi),  # Hinge around y
        top_panel_in_top_hinge)
    front_hinge_in_body = gm.dot(top_panel_in_body, front_hinge_in_top_panel)

    # Point b in top panel reference frame
    # (both transform factors are pure translations, so their order does not
    # matter for the resulting position)
    point_b_in_top_hinge = gm.pos_of(
        gm.dot(gm.diag(1, 0, 1, 1), front_hinge_in_top_panel,
               top_panel_in_top_hinge))
    point_b = gm.dot(gm.diag(1, 0, 1, 1), gm.pos_of(front_hinge_in_body))
    # Hinge lift arm in body reference frame
    # NOTE(review): despite the comment above, point c is expressed in the
    # bottom panel frame (see its name and the marker transform used).
    point_c_in_bottom_panel = gm.dot(
        gm.diag(1, 0, 1, 1),
        bottom_panel_marker_in_bottom_panel,
        gm.point3(-0.094, -0.034, -0.072),
        # gm.point3(-gm.Symbol('point_c_x'), -0.034, -gm.Symbol('point_c_z'))
    )
    point_c_in_front_hinge = gm.dot(
        gm.diag(1, 0, 1, 1),
        gm.dot(bottom_panel_in_front_hinge, point_c_in_bottom_panel))
    # Side lengths of the linkage (see sketch: Z, T, B, L, and diagonal V)
    length_z = gm.norm(point_a - point_d)

    vec_a_to_b = point_b - point_a
    length_t = gm.norm(vec_a_to_b)
    length_b = gm.norm(point_c_in_front_hinge[:3])
    # length_l = gm.Symbol('length_l') # 0.34
    length_l = 0.372

    vec_b_to_d = point_d - point_b
    length_v = gm.norm(vec_b_to_d)
    # Angles at b in the triangles (T, V, Z) and (B, V, L); inner_triangle_angle
    # is defined elsewhere (presumably law of cosines — confirm)
    gamma_1 = inner_triangle_angle(length_t, length_v, length_z)
    gamma_2 = inner_triangle_angle(length_b, length_v, length_l)

    # Fixed angular offsets of b and c relative to their hinge frames' x axes
    top_panel_offset_angle = gm.atan2(point_b_in_top_hinge[2],
                                      point_b_in_top_hinge[0])
    bottom_offset_angle = gm.atan2(point_c_in_front_hinge[2],
                                   point_c_in_front_hinge[0])

    gamma = gamma_1 + gamma_2

    # Top panel: parent link is the body
    rb_panel_top = RigidBody(l_prefix + ('body', ),
                             gm.dot(rb_body.pose, top_panel_in_body),
                             top_panel_in_body,
                             geometry={0: geom_panel_top},
                             collision={0: geom_panel_top})

    # old offset: 0.5 * geom_panel_top.scale[2] + 0.03
    # Bottom panel relative to the top panel, rotated about y by the linkage angle
    tf_bottom_panel = gm.dot(
        front_hinge_in_top_panel,
        gm.rotation3_axis_angle(
            gm.vector3(0, 1, 0),
            math.pi + bottom_offset_angle - top_panel_offset_angle),
        gm.rotation3_axis_angle(gm.vector3(0, -1, 0), gamma),
        bottom_panel_in_front_hinge)

    # Bottom panel: parent link is the top panel
    rb_panel_bottom = RigidBody(l_prefix + ('panel_top', ),
                                gm.dot(rb_panel_top.pose, tf_bottom_panel),
                                tf_bottom_panel,
                                geometry={0: geom_panel_bottom},
                                collision={0: geom_panel_bottom})

    # Handle: parent link is the bottom panel
    handle_transform = gm.dot(
        gm.translation3(geom_panel_bottom.scale[0] * 0.5 - 0.08, 0,
                        0.5 * wall_width),
        gm.rotation3_axis_angle(gm.vector3(0, 1, 0), -math.pi * 0.5))
    rb_handle = RigidBody(l_prefix + ('panel_bottom', ),
                          gm.dot(rb_panel_bottom.pose, handle_transform),
                          handle_transform,
                          geometry={x: g
                                    for x, g in enumerate(handle_geom)},
                          collision={x: g
                                     for x, g in enumerate(handle_geom)})
    # Only debugging
    point_c = gm.dot(rb_panel_bottom.pose, point_c_in_bottom_panel)
    vec_b_to_c = point_c - point_b

    # rb_panel_top.parent is l_prefix + ('body', ) and rb_panel_bottom.parent is
    # l_prefix + ('panel_top', ), so each link is created at its own path.
    km.apply_operation(f'create {prefix}/links/body',
                       CreateValue(rb_panel_top.parent, rb_body))
    km.apply_operation(f'create {prefix}/links/panel_top',
                       CreateValue(rb_panel_bottom.parent, rb_panel_top))
    km.apply_operation(
        f'create {prefix}/links/panel_bottom',
        CreateValue(l_prefix + ('panel_bottom', ), rb_panel_bottom))
    km.apply_operation(f'create {prefix}/links/handle',
                       CreateValue(l_prefix + ('handle', ), rb_handle))
    # Revolute joint about y with position limits [0, 1.84] and velocity limits
    # [-0.25, 0.25]; position bounds are given relative to the current position.
    km.apply_operation(
        f'create {prefix}/joints/hinge',
        ExecFunction(
            prefix + Path('joints/hinge'), RevoluteJoint,
            CPath(rb_panel_top.parent), CPath(rb_panel_bottom.parent),
            opening_position, gm.vector3(0, 1, 0), gm.eye(4), 0, 1.84, **{
                f'{opening_position}':
                Constraint(0 - opening_position, 1.84 - opening_position,
                           opening_position),
                f'{gm.DiffSymbol(opening_position)}':
                Constraint(-0.25, 0.25, gm.DiffSymbol(opening_position))
            }))
    # Marker frames attached to the body and both panels
    m_prefix = prefix + ('markers', )
    km.apply_operation(
        f'create {prefix}/markers/body',
        ExecFunction(m_prefix + ('body', ), Frame,
                     CPath(l_prefix + ('body', )),
                     gm.dot(rb_body.pose,
                            body_marker_in_body), body_marker_in_body))
    km.apply_operation(
        f'create {prefix}/markers/top_panel',
        ExecFunction(m_prefix + ('top_panel', ), Frame,
                     CPath(l_prefix + ('panel_top', )),
                     gm.dot(rb_panel_top.pose, top_panel_marker_in_top_panel),
                     top_panel_marker_in_top_panel))
    km.apply_operation(
        f'create {prefix}/markers/bottom_panel',
        ExecFunction(
            m_prefix + ('bottom_panel', ), Frame,
            CPath(l_prefix + ('panel_bottom', )),
            gm.dot(rb_panel_bottom.pose, bottom_panel_marker_in_bottom_panel),
            bottom_panel_marker_in_bottom_panel))

    # Debug output; the "check" entries compare gamma_1/gamma_2 against
    # alternative computations and should evaluate to ~0 if consistent.
    return NobiliaDebug(
        [
            top_hinge_in_body,
            gm.dot(
                top_hinge_in_body,
                gm.rotation3_axis_angle(gm.vector3(0, 1, 0),
                                        -opening_position + 0.5 * math.pi),
                top_panel_in_top_hinge, front_hinge_in_top_panel),
            body_marker_in_body,
            gm.dot(rb_panel_top.pose, top_panel_marker_in_top_panel),
            gm.dot(rb_panel_bottom.pose, bottom_panel_marker_in_bottom_panel)
        ], [(point_a, vec_a_to_d), (point_a, vec_a_to_b),
            (point_b, vec_b_to_d), (point_b, vec_b_to_c)],
        {
            'gamma_1':
            gamma_1,
            'gamma_1 check_dot':
            gamma_1 - gm.acos(
                gm.dot_product(-vec_a_to_b / gm.norm(vec_a_to_b),
                               vec_b_to_d / gm.norm(vec_b_to_d))),
            'gamma_1 check_cos':
            gamma_1 - inner_triangle_angle(
                gm.norm(vec_a_to_b), gm.norm(vec_b_to_d), gm.norm(vec_a_to_d)),
            'gamma_2':
            gamma_2,
            'gamma_2 check_dot':
            gamma_2 - gm.acos(
                gm.dot_product(vec_b_to_c / gm.norm(vec_b_to_c),
                               vec_b_to_d / gm.norm(vec_b_to_d))),
            'length_v':
            length_v,
            'length_b':
            length_b,
            'length_l':
            length_l,
            'position':
            opening_position,
            'alpha':
            alpha,
            'dist c d':
            gm.norm(point_d - point_c)
        }, {
            gm.Symbol('point_c_x'): 0.094,
            gm.Symbol('point_c_z'): 0.072,
            gm.Symbol('point_d_x'): 0.09,
            gm.Symbol('point_d_z'): 0.192,
            gm.Symbol('length_l'): 0.372
        })
class EKFModel(object):
    """Extended Kalman filter estimating the state behind symbolic observations.

    The state vector consists of all free symbols of the observation
    expressions, sorted by their string representation (``ordered_vars``).
    The controls are the remaining free symbols of the transition functions,
    excluding ``dt`` (``ordered_controls``). With the default transition rule
    ``s + DiffSymbol(s) * dt``, these are the derivative symbols of the state.
    """
    DT_SYM = gm.Symbol('dt')

    def __init__(self,
                 observations,
                 constraints,
                 Q=None,
                 transition_rules=None,
                 trim_threshold=None):
        """Sets up an EKF estimating the underlying state of a set of observations.

        Args:
            observations (dict): A dict of observations. Names are mapped to
                                 any kind of symbolic expression/matrix.
                                 Non-symbolic entries are ignored.
            constraints (dict): A dict of named constraints that govern the
                                configuration space of the estimated quantities.
                                Only constraints whose expression is a single
                                symbol, with bounds depending on nothing but that
                                symbol, are used. They define ``state_bounds``;
                                unconstrained state symbols default to [-pi, pi].
            Q (matrix, optional): Process noise of the estimated quantities.
                                  Note: Quantities are expected to be ordered
                                  alphabetically. Defaults to a zero matrix.
            transition_rules (dict, optional): Maps symbols to their transition rule.
                                               Rules will be generated automatically,
                                               if not provided here. A rule that
                                               references symbols other than the
                                               state, its derivatives and ``dt``
                                               is dropped with a printed message.
            trim_threshold (float, optional): If given, ``update`` passes the
                                              observation Jacobian to
                                              ``trim_singularities`` with this
                                              threshold and uses only the
                                              returned observation rows.
        """
        state_vars = union([gm.free_symbols(o) for o in observations.values()])

        self.ordered_vars = [
            s for _, s in sorted((str(s), s) for s in state_vars)
        ]
        self.Q = Q if Q is not None else np.zeros(
            (len(self.ordered_vars), len(self.ordered_vars)))

        # Default transition: explicit Euler step with the symbol's derivative
        st_fn = {}
        for s in self.ordered_vars:
            st_fn[s] = gm.wrap_expr(s + gm.DiffSymbol(s) * EKFModel.DT_SYM)

        if transition_rules is not None:
            varset = set(self.ordered_vars).union(
                {gm.DiffSymbol(s)
                 for s in self.ordered_vars}).union({EKFModel.DT_SYM})
            for s, r in transition_rules.items():
                if s in st_fn:
                    if len(gm.free_symbols(r).difference(varset)) == 0:
                        st_fn[s] = gm.wrap_expr(r)
                    else:
                        print(
                            f'Dropping rule "{s}: {r}". Symbols missing from state: {gm.free_symbols(r).difference(varset)}'
                        )
        control_vars = union([gm.free_symbols(r) for r in st_fn.values()]) \
                        .difference(self.ordered_vars)            \
                        .difference({EKFModel.DT_SYM})
        self.ordered_controls = [
            s for _, s in sorted((str(s), s) for s in control_vars)
        ]

        # State as column vector n * 1
        temp_g_fn = gm.Matrix(
            [gm.extract_expr(st_fn[s]) for s in self.ordered_vars])
        self.g_fn = gm.speed_up(temp_g_fn, [EKFModel.DT_SYM] +
                                self.ordered_vars + self.ordered_controls)
        # NOTE(review): the Jacobians below index gradients by the symbols in
        # ordered_controls. That yields derivatives w.r.t. the state only if
        # gradient containers key their entries by derivative symbol and the
        # control order matches the state order — confirm.
        temp_g_prime_fn = gm.Matrix([[
            gm.extract_expr(st_fn[s][d]) if d in st_fn[s] else 0
            for d in self.ordered_controls
        ] for s in self.ordered_vars])
        self.g_prime_fn = gm.speed_up(temp_g_prime_fn,
                                      [EKFModel.DT_SYM] + self.ordered_vars +
                                      self.ordered_controls)

        # Flatten all symbolic observation components into one vector.
        # ``takers`` records, per observation label, which flat (row-major)
        # entries of the raw observation are symbolic and thus observed.
        self.obs_labels = []
        self.takers = []
        flat_obs = []
        for o_label, o in sorted(observations.items()):
            if gm.is_symbolic(o):
                if gm.is_matrix(o):
                    if type(o) == gm.GM:
                        components = zip(
                            sum([[(y, x) for x in range(o.shape[1])]
                                 for y in range(o.shape[0])], []), iter(o))
                    else:
                        components = zip(
                            sum([[(y, x) for x in range(o.shape[1])]
                                 for y in range(o.shape[0])], []),
                            o.T.elements())  # Casadi iterates vertically
                    indices = []
                    for coords, c in components:
                        if gm.is_symbolic(c):
                            self.obs_labels.append('{}_{}_{}'.format(
                                o_label, *coords))
                            flat_obs.append(gm.wrap_expr(c))
                            indices.append(coords[0] * o.shape[1] + coords[1])
                    if len(indices) > 0:
                        self.takers.append((o_label, indices))
                else:
                    self.obs_labels.append(o_label)
                    flat_obs.append(gm.wrap_expr(o))
                    self.takers.append((o_label, [0]))

        temp_h_fn = gm.Matrix([gm.extract_expr(o) for o in flat_obs])
        self.h_fn = gm.speed_up(temp_h_fn, self.ordered_vars)
        temp_h_prime_fn = gm.Matrix([[
            gm.extract_expr(o[d]) if d in o else 0
            for d in self.ordered_controls
        ] for o in flat_obs])
        self.h_prime_fn = gm.speed_up(temp_h_prime_fn, self.ordered_vars)

        # Extract constant box bounds for single-symbol constraints
        state_constraints = {}
        for n, c in constraints.items():
            if gm.is_symbol(c.expr):
                s = gm.free_symbols(c.expr).pop()
                fs = gm.free_symbols(c.lower).union(gm.free_symbols(c.upper))
                if len(fs.difference({s})) == 0:
                    state_constraints[s] = (float(gm.subs(c.lower, {s: 0})),
                                            float(gm.subs(c.upper, {s: 0})))

        self.state_bounds = np.array([
            state_constraints[s]
            if s in state_constraints else [-np.pi, np.pi]
            for s in self.ordered_vars
        ])
        self.R = None  # np.zeros((len(self.obs_labels), len(self.obs_labels)))
        self.trim_threshold = trim_threshold

    def gen_control_vector(self, control_dict):
        """Generates a control vector from a given dict mapping symbols to values

        Args:
            control_dict (dict): Map symbol -> float

        Returns:
            np.ndarray: Vectorized control, ordered like ``ordered_controls``
        """
        return np.array([control_dict[s] for s in self.ordered_controls])

    def gen_obs_vector(self, obs_dict):
        """Generates a flat vector of observations from a dictionary of them.

        Args:
            obs_dict (dict): Dict of observations. Names need to match those
                             given to __init__
        Returns:
            np.ndarray: Flat vector of observations
        """
        return np.hstack([np.take(obs_dict[l], i) for l, i in self.takers])

    def pandas_R(self):
        """Returns the measurement covariance matrix as pandas frame.

        Returns:
            pd.DataFrame: Measurement covariance as DataFrame if matrix is stored.
                          None otherwise.
        """
        if self.R is None:
            return None
        return pd.DataFrame(data=self.R,
                            index=self.obs_labels,
                            columns=self.obs_labels)

    @profile
    def predict(self, state_t, Sigma_t, control, dt=0.05):
        """Predicts the next state and covariance from current state, covariance, and
           control signal.

        Args:
            state_t (np.ndarray): Current state vector (flat)
            Sigma_t (np.ndarray): Current covariance
            control (nd.ndarray): Current control vector
            dt (float, optional): Size of the prediction step

        Returns:
            (np.ndarray, np.ndarray): predicted_state, predicted_covariance
        """
        params = np.hstack(([dt], state_t, control))
        F_t = self.g_prime_fn.call2(params)
        return self.g_fn.call2(params), F_t.dot(Sigma_t.dot(F_t.T)) + self.Q

    @profile
    def update(self, state_t, Sigma_t, obs_t):
        """Performs the kalman update.

        Args:
            state_t (np.ndarray): Current state, either flat (n,) or a column (n, 1)
            Sigma_t (np.ndarray): Current covariance
            obs_t (np.ndarray): Current observation (see ``gen_obs_vector``)

        Returns:
            (np.ndarray, np.ndarray): Updated state as a flat vector, updated
                covariance. If the innovation covariance is singular, a message
                is printed and the (flattened) input state and unchanged
                covariance are returned.

        Raises:
            Exception: Will raise an exception if the measurement covariance R is not set
        """
        if self.R is None:
            raise Exception('No noise model set for EKF model.')

        # Normalize the state to a flat vector so the correction below cannot
        # broadcast an (n,) state against an (n, 1) correction into (n, n).
        flat_state = np.asarray(state_t).flatten()

        H_t = self.h_prime_fn.call2(state_t)
        h_t = self.h_fn.call2(state_t)

        if self.trim_threshold is not None:
            idx_obs, H_t = trim_singularities(H_t, self.trim_threshold)
        else:
            idx_obs = range(H_t.shape[0])

        # After reducing the observation, R will not necessarily match anymore
        reduced_R = np.take(np.take(self.R, idx_obs, axis=0), idx_obs, axis=1)
        S_t = H_t.dot(Sigma_t.dot(H_t.T)) + reduced_R
        try:
            # More reliable than testing det(S_t) != 0, which can underflow or
            # overflow independently of whether S_t is invertible.
            S_t_inv = np.linalg.inv(S_t)
        except np.linalg.LinAlgError:
            print('Determinant of S_t is 0')
            return flat_state, Sigma_t

        K_t = Sigma_t.dot(H_t.T.dot(S_t_inv))
        y_t = np.take(obs_t - h_t.flatten(), idx_obs).reshape(
            (len(idx_obs), 1))
        new_state = flat_state + K_t.dot(y_t).flatten()
        new_Sigma = (np.eye(Sigma_t.shape[0]) - K_t.dot(H_t)).dot(Sigma_t)
        return new_state, new_Sigma

    def set_R(self, R):
        """Set the measurement covariance matrix R

        Args:
            R (np.ndarray): Measurement covariance matrix
        """
        self.R = R

    @profile
    def generate_R(self, noisy_observations):
        """Generates the covariance matrix from a set of noisy observations.
           Once generated, the matrix will be stored in the model.

        Args:
            noisy_observations ([dict]): List of observation dictionaries.
        """
        obs = np.vstack(
            [self.gen_obs_vector(obs) for obs in noisy_observations])
        cov = np.cov(obs.T)
        self.set_R(cov)

    def spawn_particle(self):
        """Spawns a particle initialized to be in the center of the configuration space.
           The covariance is a diagonal matrix holding, per state dimension,
           the squared width of that dimension's bounds.

        Returns:
            Particle: A particle modeling an estimate of the state of this model
        """
        return Particle(
            (self.state_bounds.T[0] + self.state_bounds.T[1]) * 0.5,
            np.diag(self.state_bounds.T[1] - self.state_bounds.T[0])**2)

    def __str__(self):
        return 'EKF estimating:\n  {}\nFrom:\n  {}\nWith controls:\n  {}'.format(
            '\n  '.join(str(s) for s in self.ordered_vars),
            '\n  '.join(l for l in self.obs_labels),
            '\n  '.join(str(c) for c in self.ordered_controls))
import math
import rospy
import numpy as np
import kineverse.gradients.gradient_math as gm
import dearpygui.dearpygui as dpg
import signal

from kineverse.model.geometry_model import GeometryModel, Path
from kineverse.visualization.bpb_visualizer import ROSBPBVisualizer

from kineverse_experiment_world.nobilia_shelf import create_nobilia_shelf

# Symbol named 'location_x'. draw_shelves assigns it a different value for every
# copy of the shelf it draws, presumably shifting each copy along x (TODO confirm
# that the shelf model is parameterised by this symbol).
x_loc_sym = gm.Symbol('location_x')


def draw_shelves(shelf, params, world, visualizer, steps=4, spacing=1):
    """Draw ``steps`` copies of the shelf at hinge positions from 0 to 1.84.

    For every sample, the hinge joint position is set and ``x_loc_sym`` is
    offset by ``(index - steps / 2) * spacing``. The world is then updated and
    drawn into the visualizer's 'world' layer. The angle between the z axis of
    ``panel_top`` and the negated z axis of ``panel_bottom`` is printed for
    each sample. All samples are rendered together at the end.

    Args:
        shelf: Shelf model providing ``links['panel_top']``,
               ``links['panel_bottom']`` and ``joints['hinge']``.
        params (dict): Base symbol assignment; it is copied, not mutated.
        world: World that receives ``update_world(state)`` for each sample.
        visualizer: Visualizer used via ``begin_draw_cycle``, ``draw_world``
                    and ``render``.
        steps (int): Number of hinge positions to draw.
        spacing (float): Offset between neighbouring copies.
    """
    state = params.copy()
    visualizer.begin_draw_cycle('world')
    panel_cos = gm.dot_product(gm.z_of(shelf.links['panel_top'].pose),
                               -gm.z_of(shelf.links['panel_bottom'].pose))

    for x, p in enumerate(np.linspace(0, 1.84, steps)):
        state[shelf.joints['hinge'].position] = p
        state[x_loc_sym] = (x - steps / 2) * spacing

        cos = gm.subs(panel_cos, state)
        # Clip: rounding can push the dot product of unit vectors past +-1,
        # which would make arccos return NaN.
        cos_value = np.clip(cos.flatten()[0], -1.0, 1.0)
        print('Angle at {:>8.2f}: {:> 8.2f}'.format(
            np.rad2deg(p), np.rad2deg(np.arccos(cos_value))))

        world.update_world(state)
        visualizer.draw_world('world', world)

    # Publish everything drawn in this cycle.
    visualizer.render('world')


# Planar hinge model: a rotation by `angle_top` about the -y axis, placed between
# an x/z offset of the hinge in its parent frame and an x/z offset of the child
# in the hinge frame.
angle_hinge = gm.Symbol('angle_top')
x_hinge_in_parent = gm.Symbol('hinge_in_parent_x')
z_hinge_in_parent = gm.Symbol('hinge_in_parent_z')
x_child_in_hinge = gm.Symbol('child_in_hinge_x')
z_child_in_hinge = gm.Symbol('child_in_hinge_z')

fwd_kinematic_hinge = gm.dot(
    gm.translation3(x_hinge_in_parent, 0, z_hinge_in_parent),
    gm.rotation3_axis_angle(gm.vector3(0, -1, 0), angle_hinge),
    gm.translation3(x_child_in_hinge, 0, z_child_in_hinge),
)

# Left-multiplying by diag(1, 0, 1, 1) zeroes the y row of the transform, so the
# residual ignores the y location. The expression is wrapped by gm.speed_up over
# all free symbols of the full hinge chain.
fwd_kinematic_hinge_residual_tf = gm.speed_up(
    gm.dot(gm.diag(1, 0, 1, 1), fwd_kinematic_hinge),
    gm.free_symbols(fwd_kinematic_hinge),
)
    def __init__(self, km, observations, transition_rules=None, max_iterations=20, num_samples=7):
        """Sets up a QP-based estimator for the state underlying a set of observations.

        Every free symbol of the observation expressions becomes a state
        variable. Each symbolic observation component gets a soft constraint
        on ``|obs_symbol - expr|``. The constraint is gated by the switch symbol
        ``<label>_observed``: tight while the switch is 1, relaxed by a large
        margin while it is 0.

        Args:
            km (ArticulationModel): Articulation model to query for constraints
            observations (dict): A dict of observations. Names are mapped to
                                 any kind of symbolic expression/matrix
            transition_rules (dict, optional): Maps symbols to their transition rule.
                                               Rules will be generated automatically, if not provided here.
            max_iterations (int, optional): Maximum number of QP iterations per update.
            num_samples (int, optional): Number of recent estimates averaged into the state.
        """
        state_vars = union([gm.free_symbols(o) for o in observations.values()])

        self.num_samples = num_samples

        (self.ordered_vars,
         self.transition_fn,
         self.transition_args) = generate_transition_function(QPStateModel.DT_SYM,
                                                              state_vars,
                                                              transition_rules)
        # Transition inputs that are neither state variables nor the time step
        # are treated as externally set commands.
        self.command_vars = {s for s in self.transition_args
                             if s not in state_vars and str(s) != str(QPStateModel.DT_SYM)}

        obs_constraints = {}

        self.switch_vars = {}
        self.obs_vars = {}
        self.takers = {}
        for o_label, o in sorted(observations.items()):
            if not gm.is_symbolic(o):
                continue

            obs_switch_var = gm.Symbol(f'{o_label}_observed')
            self.switch_vars[o_label] = obs_switch_var
            label_constraints = obs_constraints.setdefault(o_label, {})
            label_vars = self.obs_vars.setdefault(o_label, [])
            # Relaxation margin: 0 while observed (switch == 1), large otherwise.
            # Used for matrix and scalar observations alike.
            slack = (1 - obs_switch_var) * 1e3

            if gm.is_matrix(o):
                if type(o) == gm.GM:
                    elements = iter(o)
                else:
                    elements = o.T.elements()  # Casadi iterates vertically
                coords = [(y, x) for y in range(o.shape[0]) for x in range(o.shape[1])]
                indices = []
                for (y, x), c in zip(coords, elements):
                    if gm.is_symbolic(c):
                        obs_symbol = gm.Symbol(f'{o_label}_{y}_{x}')
                        obs_error = gm.abs(obs_symbol - c)
                        label_constraints[f'{o_label}:{Path(obs_symbol)}'] = SC(
                            -obs_error - slack, -obs_error + slack, 1, obs_error)
                        label_vars.append(obs_symbol)
                        # Row-major flat index used by update() to pick observed values.
                        indices.append(y * o.shape[1] + x)

                if len(indices) > 0:
                    self.takers[o_label] = indices
            else:
                obs_symbol = gm.Symbol(f'{o_label}_value')
                # Compare against the scalar observation itself, not a leftover
                # matrix component from an earlier loop iteration.
                obs_error = gm.abs(obs_symbol - o)
                label_constraints[f'{o_label}:{Path(obs_symbol)}'] = SC(
                    -obs_error - slack, -obs_error + slack, 1, obs_error)
                label_vars.append(obs_symbol)
                self.takers[o_label] = [0]

        state_constraints = km.get_constraints_by_symbols(state_vars)

        cvs, hard_constraints = generate_controlled_values(
            state_constraints,
            {gm.DiffSymbol(s) for s in state_vars
             if gm.get_symbol_type(s) != gm.TYPE_UNKNOWN})
        flat_obs_constraints = {name: c
                                for oc in obs_constraints.values()
                                for name, c in oc.items()}

        self.qp = TQPB(hard_constraints, flat_obs_constraints, cvs)

        st_bound_vars, st_bounds, st_unbounded = static_var_bounds(km, state_vars)
        # Unbounded variables start at 0, bounded ones uniformly within their bounds.
        self._state = {s: 0 for s in st_unbounded}
        for vb, (lb, ub) in zip(st_bound_vars, st_bounds):
            self._state[vb] = np.random.uniform(lb, ub)

        self._state_buffer = []
        # NOTE(review): if transition_args includes the state variables, this
        # resets the random initialisation above to 0 — confirm this is intended.
        self._state.update({s: 0 for s in self.transition_args})
        self._obs_state = {s: 0 for s in sum(self.obs_vars.values(), [])}
        self._obs_count = 0
        self._stamp_last_integration = None
        self._max_iterations = max_iterations
        self._current_error = 1e9
class QPStateModel(object):
    DT_SYM = gm.Symbol('dt')

    def __init__(self, km, observations, transition_rules=None, max_iterations=20, num_samples=7):
        """Sets up a QP-based estimator for the state underlying a set of observations.

        Every free symbol of the observation expressions becomes a state
        variable. Each symbolic observation component gets a soft constraint
        on ``|obs_symbol - expr|``. The constraint is gated by the switch symbol
        ``<label>_observed``: tight while the switch is 1, relaxed by a large
        margin while it is 0.

        Args:
            km (ArticulationModel): Articulation model to query for constraints
            observations (dict): A dict of observations. Names are mapped to
                                 any kind of symbolic expression/matrix
            transition_rules (dict, optional): Maps symbols to their transition rule.
                                               Rules will be generated automatically, if not provided here.
            max_iterations (int, optional): Maximum number of QP iterations per update.
            num_samples (int, optional): Number of recent estimates averaged into the state.
        """
        state_vars = union([gm.free_symbols(o) for o in observations.values()])

        self.num_samples = num_samples

        (self.ordered_vars,
         self.transition_fn,
         self.transition_args) = generate_transition_function(QPStateModel.DT_SYM,
                                                              state_vars,
                                                              transition_rules)
        # Transition inputs that are neither state variables nor the time step
        # are treated as externally set commands.
        self.command_vars = {s for s in self.transition_args
                             if s not in state_vars and str(s) != str(QPStateModel.DT_SYM)}

        obs_constraints = {}

        self.switch_vars = {}
        self.obs_vars = {}
        self.takers = {}
        for o_label, o in sorted(observations.items()):
            if not gm.is_symbolic(o):
                continue

            obs_switch_var = gm.Symbol(f'{o_label}_observed')
            self.switch_vars[o_label] = obs_switch_var
            label_constraints = obs_constraints.setdefault(o_label, {})
            label_vars = self.obs_vars.setdefault(o_label, [])
            # Relaxation margin: 0 while observed (switch == 1), large otherwise.
            # Used for matrix and scalar observations alike.
            slack = (1 - obs_switch_var) * 1e3

            if gm.is_matrix(o):
                if type(o) == gm.GM:
                    elements = iter(o)
                else:
                    elements = o.T.elements()  # Casadi iterates vertically
                coords = [(y, x) for y in range(o.shape[0]) for x in range(o.shape[1])]
                indices = []
                for (y, x), c in zip(coords, elements):
                    if gm.is_symbolic(c):
                        obs_symbol = gm.Symbol(f'{o_label}_{y}_{x}')
                        obs_error = gm.abs(obs_symbol - c)
                        label_constraints[f'{o_label}:{Path(obs_symbol)}'] = SC(
                            -obs_error - slack, -obs_error + slack, 1, obs_error)
                        label_vars.append(obs_symbol)
                        # Row-major flat index used by update() to pick observed values.
                        indices.append(y * o.shape[1] + x)

                if len(indices) > 0:
                    self.takers[o_label] = indices
            else:
                obs_symbol = gm.Symbol(f'{o_label}_value')
                # Compare against the scalar observation itself, not a leftover
                # matrix component from an earlier loop iteration.
                obs_error = gm.abs(obs_symbol - o)
                label_constraints[f'{o_label}:{Path(obs_symbol)}'] = SC(
                    -obs_error - slack, -obs_error + slack, 1, obs_error)
                label_vars.append(obs_symbol)
                self.takers[o_label] = [0]

        state_constraints = km.get_constraints_by_symbols(state_vars)

        cvs, hard_constraints = generate_controlled_values(
            state_constraints,
            {gm.DiffSymbol(s) for s in state_vars
             if gm.get_symbol_type(s) != gm.TYPE_UNKNOWN})
        flat_obs_constraints = {name: c
                                for oc in obs_constraints.values()
                                for name, c in oc.items()}

        self.qp = TQPB(hard_constraints, flat_obs_constraints, cvs)

        st_bound_vars, st_bounds, st_unbounded = static_var_bounds(km, state_vars)
        # Unbounded variables start at 0, bounded ones uniformly within their bounds.
        self._state = {s: 0 for s in st_unbounded}
        for vb, (lb, ub) in zip(st_bound_vars, st_bounds):
            self._state[vb] = np.random.uniform(lb, ub)

        self._state_buffer = []
        # NOTE(review): if transition_args includes the state variables, this
        # resets the random initialisation above to 0 — confirm this is intended.
        self._state.update({s: 0 for s in self.transition_args})
        self._obs_state = {s: 0 for s in sum(self.obs_vars.values(), [])}
        self._obs_count = 0
        self._stamp_last_integration = None
        self._max_iterations = max_iterations
        self._current_error = 1e9

    def _integrate_state(self):
        """Propagate the internal state up to the current time.

        Does nothing on the first call. On later calls:
        - the time elapsed since the previous call is stored under ``DT_SYM``;
        - the transition function is evaluated, and every state variable is
          replaced by its propagated value;
        - the same per-variable change is added to every buffered sample.

        The call time is always recorded for the next call.
        """
        now = Time.now()
        previous = self._stamp_last_integration
        if previous is not None:
            self._state[QPStateModel.DT_SYM] = (now - previous).to_sec()
            arguments = [self._state[arg] for arg in self.transition_args]
            propagated = self.transition_fn.call2(arguments).flatten()
            for idx, (sym, value) in enumerate(zip(self.ordered_vars, propagated)):
                shift = value - self._state[sym]
                self._state[sym] = value
                # Keep buffered estimates consistent with the propagated state.
                for sample in self._state_buffer:
                    sample[idx] += shift

        self._stamp_last_integration = now

    def set_command(self, command):
        """Integrate up to now, then adopt new values for the command variables.

        Args:
            command (dict): Maps symbols to values. Only entries whose symbol is
                            in ``self.command_vars`` are used; others are ignored.
        """
        self._integrate_state()
        accepted = {sym: command[sym] for sym in self.command_vars if sym in command}
        self._state.update(accepted)

    @profile
    def update(self, observation):
        """Fuse a new set of observations into the state estimate.

        Steps:
        1. Integrate the state up to now.
        2. For each known observation label present in ``observation``, set
           its switch symbol to 1 and copy the selected observed values into
           the QP's observation variables. Labels that are absent get their
           switch set to 0.
        3. Iterate the QP at most ``self._max_iterations`` times with a fixed
           step of 0.5, stopping early once it reports equilibrium.
        4. Append the result to a buffer of the last ``self.num_samples``
           estimates; the buffer's mean becomes the new state.

        Args:
            observation (dict): Maps observation labels to observed values.

        Returns:
            The QP's ``latest_error`` after the iterations.
        """
        self._integrate_state()

        for o_label, o_vars in self.obs_vars.items():
            switch_var = self.switch_vars[o_label]
            if o_label in observation:
                self._obs_state[switch_var] = 1
                # A matrix observation without symbolic components has no taker
                # entry; select nothing instead of raising KeyError.
                sub_obs = np.take(observation[o_label], self.takers.get(o_label, []))
                self._obs_state.update(zip(o_vars, sub_obs))
            else:
                self._obs_state[switch_var] = 0

        step = 0.5
        self._obs_state.update(self._state)
        self._obs_state[QPStateModel.DT_SYM] = step

        for _ in range(self._max_iterations):
            cmd = self.qp.get_cmd(self._obs_state, deltaT=step)
            self._obs_state.update(cmd)
            # Flatten (as _integrate_state does) so each state variable receives
            # one scalar rather than a row of the transition result.
            new_state = self.transition_fn.call2(
                [self._obs_state[s] for s in self.transition_args]).flatten()
            self._obs_state.update(zip(self.ordered_vars, new_state))
            if self.qp.equilibrium_reached():
                break

        self._obs_count += 1
        x_n = np.array([self._obs_state[s] for s in self.ordered_vars]).flatten()

        # Keep only the most recent num_samples estimates and average them.
        self._state_buffer.append(x_n)
        if len(self._state_buffer) > self.num_samples:
            self._state_buffer = self._state_buffer[-self.num_samples:]

        state_mean = np.mean(self._state_buffer, axis=0)
        self._state.update(zip(self.ordered_vars, state_mean))
        return self.qp.latest_error

    def state(self):
        """Return the current estimate after integrating up to now.

        Returns:
            dict: Maps every symbol of ``self.ordered_vars`` to its current value.
        """
        self._integrate_state()
        current = self._state
        return dict((sym, current[sym]) for sym in self.ordered_vars)

    @property
    def latest_error(self):
        """Most recent error reported by the underlying QP."""
        solver = self.qp
        return solver.latest_error

    def __str__(self):
        """List the estimated variables, observation symbols and commands, one per line."""
        separator = '\n  '
        estimated = separator.join(str(s) for s in self.ordered_vars)
        observed = separator.join(sorted(str(s)
                                         for label_vars in self.obs_vars.values()
                                         for s in label_vars))
        commands = separator.join(sorted(str(c) for c in self.command_vars))
        return 'QP estimating:\n  {}\nFrom:\n  {}\nWith controls:\n  {}'.format(
            estimated, observed, commands)