def project_to_surface(orb, times):
    """Project an orbit's positions at the given times onto its primary's surface.

    Positions come from ``orb.get_positions(times=times)[0]`` and are converted
    with ``cartesian_to_spherical``; the first (radial) column is dropped. The
    remaining first column is treated as longitude: the primary's rotation
    angle at each time (``2*pi/rotPeriod * t + rotIni``, passed through
    ``Orbit.map_angle``) is subtracted and the result is wrapped into
    [-pi, pi].

    Args:
        orb: orbit whose ``prim`` provides ``rotPeriod`` and ``rotIni``.
        times: sequence of times, in the same units as ``prim.rotPeriod``.

    Returns:
        ``cartesian_to_spherical(positions)[:, 1:]`` with column 0 replaced by
        the body-fixed longitude in radians. (Presumably column 1 is the
        latitude/colatitude angle -- confirm against cartesian_to_spherical.)
    """
    bd = orb.prim
    positions = orb.get_positions(times=times)[0]

    bodyThetas = 2*np.pi/bd.rotPeriod * np.array(times) + bd.rotIni
    sphericalPositions = cartesian_to_spherical(positions)

    surfaceCoords = sphericalPositions[:, 1:]
    for ii, bodyTheta in enumerate(bodyThetas):
        longitude = surfaceCoords[ii, 0] - Orbit.map_angle(bodyTheta)
        # Wrap into [-pi, pi] in one step, for any magnitude; values already
        # in range are left exactly as they are.
        if abs(longitude) > np.pi:
            longitude = (longitude + np.pi) % (2*np.pi) - np.pi
        surfaceCoords[ii, 0] = longitude

    return surfaceCoords


def _anomaly_to_index(lo, hi, n, m):
    """Return the fractional index k in [0, n] at which the cosine spacing of
    [lo, hi] with n intervals (see _cosine_spaced) reaches mean anomaly m."""
    ratio = (lo + hi - 2*m)/(hi - lo)
    # When m sits on an endpoint, rounding can push ratio a hair past +/-1
    # and math.acos would raise a domain error; snap it back. Larger
    # excursions (m genuinely outside [lo, hi]) still raise.
    if 1.0 < abs(ratio) <= 1.0 + 1e-9:
        ratio = math.copysign(1.0, ratio)
    return n*math.acos(ratio)/math.pi


def _cosine_spaced(lo, hi, n, ks):
    """Map (possibly fractional) indices ks onto [lo, hi] with cosine spacing.

    Index 0 maps to lo and index n to hi; consecutive integer indices lie
    closer together near the two ends than in the middle.
    """
    return [0.5*(lo + hi) + 0.5*(hi - lo)*math.cos((n - k)/n*math.pi)
            for k in ks]


def _sample_mean_anomalies(mStart, mEnd, numPts, elliptic):
    """Return an array of mean anomalies from mStart to mEnd, denser near apses.

    The arc is split at the apses it crosses (mean anomalies that are
    multiples of pi for an ellipse; 0 for a hyperbola) and each piece is
    sampled with cosine spacing, so points cluster at periapsis/apoapsis.
    """
    # Bracket mStart between the apses on either side of it.
    if elliptic:
        a = mStart - (mStart % math.pi)
        b = a + math.pi
    elif mStart < 0:
        a, b = mStart, 0
    else:
        a, b = 0, mEnd

    if elliptic and mEnd >= b + math.pi:
        # Arc crosses two apses: pieces [mStart, b], [b, c], [c, mEnd].
        c = b + math.pi
        d = c + math.pi
        # n intervals per half revolution -> about numPts over the whole arc.
        n = math.ceil(math.pi/(mEnd - mStart)*numPts)
        kStart = _anomaly_to_index(a, b, n, mStart)
        kEnd = _anomaly_to_index(c, d, n, mEnd)
        # floor(kStart)+1 skips a duplicate vertex when kStart is integral.
        return np.concatenate((
            _cosine_spaced(a, b, n,
                           [kStart, *range(math.floor(kStart) + 1, n)]),
            _cosine_spaced(b, c, n, range(n)),
            _cosine_spaced(c, d, n, [*range(math.ceil(kEnd)), kEnd]),
        ))

    if mEnd > b:
        # Arc crosses one apsis: pieces [mStart, b], [b, mEnd].
        if elliptic:
            c = b + math.pi
            n1 = math.ceil(math.pi/(mEnd - mStart)*numPts)
            n2 = n1
        else:
            # Hyperbola: split numPts in proportion to each side of M = 0.
            c = mEnd
            n1 = math.ceil(abs(mStart/(mEnd - mStart))*numPts)
            n2 = math.ceil(abs(mEnd/(mEnd - mStart))*numPts)
        kStart = _anomaly_to_index(a, b, n1, mStart)
        kEnd = _anomaly_to_index(b, c, n2, mEnd)
        return np.concatenate((
            _cosine_spaced(a, b, n1,
                           [kStart, *range(math.floor(kStart) + 1, n1)]),
            _cosine_spaced(b, c, n2, [*range(math.ceil(kEnd)), kEnd]),
        ))

    # Arc lies between two consecutive apses.
    if elliptic:
        # NOTE(review): 2*pi here (vs pi in the branches above) gives roughly
        # twice the sample density -- confirm this is intended.
        n = math.ceil(2*math.pi/(mEnd - mStart)*numPts)
    else:
        n = numPts
    kStart = _anomaly_to_index(a, b, n, mStart)
    kEnd = _anomaly_to_index(a, b, n, mEnd)
    ks = [kStart, *range(math.floor(kStart) + 1, math.ceil(kEnd)), kEnd]
    return np.array(_cosine_spaced(a, b, n, ks))


def add_orbit(figure, orb, startTime, endTime=None, numPts=201,
              dateFormat=None, apses=False, nodes=False, fullPeriod=True,
              color=(255,255,255), name='', style='solid', fade=True,):
    """Draw an orbit's path between two times as a line on a 3-D plotly figure.

    Mean anomalies between the start and end are sampled more densely near
    the apses and converted to times. The orbit's positions at those times
    are added as a ``go.Scatter3d`` line whose color runs from
    ``fade_color(color, 3)`` at the earliest point to ``color`` at the latest.

    Args:
        figure: plotly figure the trace is added to.
        orb: orbit; must provide get_period, get_mean_anomaly, get_positions
            and the elements ecc, a, inc, argp, lan, mo, epoch.
        startTime: time of the first sample, in seconds.
        endTime: time of the last sample. Required unless fullPeriod is True.
        numPts: target number of samples along the arc (approximate).
        dateFormat: optional dict with 'day' (hours per day) and 'year'
            (days per year). When given, the hover text shows the date, speed,
            mean anomaly and orbital elements; otherwise it shows only radius.
        apses: also call add_apses(figure, orb).
        nodes: also call add_nodes(figure, orb).
        fullPeriod: for elliptical orbits, draw exactly one revolution from
            startTime regardless of endTime. For any orbit, endTime defaults
            to startTime + period when it is None.
        color: RGB tuple for the line's latest point.
        name: trace name.
        style: plotly line dash style.
        fade: if False, the whole line is drawn in ``color``.

    Raises:
        ValueError: if endTime is None and fullPeriod is False.
    """
    fadedColor = fade_color(color, 3) if fade else color

    period = orb.get_period()
    if endTime is None:
        if not fullPeriod:
            raise ValueError("endTime must be given when fullPeriod is False")
        endTime = startTime + period
    elliptic = orb.ecc < 1

    # Mean anomalies at the ends of the arc. An ellipse is drawn for at most
    # one revolution; otherwise mean anomaly advances at 2*pi/period per unit
    # time.
    mStart = orb.get_mean_anomaly(startTime)
    if (period < (endTime - startTime) or fullPeriod) and elliptic:
        mEnd = mStart + 2*math.pi
    else:
        mEnd = mStart + 2*math.pi/period * (endTime - startTime)

    meanAnoms = _sample_mean_anomalies(mStart, mEnd, numPts, elliptic)
    times = startTime + period/(2*math.pi) * (meanAnoms - mStart)

    pos, vel = orb.get_positions(times=times)
    pos = np.transpose(pos)
    vel = np.transpose(vel)

    if elliptic:
        # Wrap for display only; times were computed from unwrapped values.
        meanAnoms[:] = [Orbit.map_angle(m) for m in meanAnoms]

    if dateFormat is not None:
        secsPerDay = 3600*dateFormat['day']
        secsPerYear = secsPerDay*dateFormat['year']
        timeOfDay = times % secsPerDay

        cData = np.stack((norm(pos, axis=0)/1000,
                          norm(vel, axis=0),
                          np.floor(times/secsPerYear) + 1,
                          np.floor((times % secsPerYear)/secsPerDay + 1),
                          np.floor(timeOfDay/3600),
                          np.floor((timeOfDay % 3600)/60),
                          np.floor((timeOfDay % 3600) % 60),
                          times,
                          meanAnoms),
                         axis=1)
        elementsLabel = ("Semi-major Axis = {:.0f} m<br>"
                         "Eccentricity = {:.4f}<br>"
                         "Inclination = {:.4f}°<br>"
                         "Argument of the Periapsis = {:.4f}°<br>"
                         "Longitude of Ascending Node = {:.4f}°<br>"
                         "Mean Anomaly at Epoch = {:.4f} rad<br>"
                         "Epoch = {:.2f} s").format(orb.a,
                                                   orb.ecc,
                                                   orb.inc*180/math.pi,
                                                   orb.argp*180/math.pi,
                                                   orb.lan*180/math.pi,
                                                   orb.mo,
                                                   orb.epoch)
        # %{customdata[i]:fmt} placeholders are filled in by plotly per point.
        hoverLabel = ("r = %{customdata[0]:.3e} km<br>"
                      "v = %{customdata[1]:.3e} m/s<br><br>"
                      "Year %{customdata[2]:.0f}, "
                      "Day %{customdata[3]:.0f} "
                      "%{customdata[4]:0>2d}:"
                      "%{customdata[5]:0>2d}:"
                      "%{customdata[6]:0>2d}<br>"
                      "UT: %{customdata[7]:.3f} s<br>"
                      "Mean Anomaly: %{customdata[8]:.5f} rad<br><br>"
                      + elementsLabel)
    else:
        cData = norm(pos, axis=0)/1000
        hoverLabel = "r = %{customdata:.4e} km"

    figure.add_trace(go.Scatter3d(
        x=pos[0],
        y=pos[1],
        z=pos[2],
        customdata=cData,
        mode="lines",
        line=dict(
            color=times,
            colorscale=[[0, 'rgb'+str(fadedColor)],
                        [1, 'rgb'+str(color)]],
            dash=style,
            ),
        hovertemplate=hoverLabel,
        name=name,
        showlegend=False,
        ))

    if nodes:
        add_nodes(figure, orb)
    if apses:
        add_apses(figure, orb)