from scipy.spatial.transform import Rotation as R
import std_msgs.msg, sensor_msgs.msg

# Command-line options of the calibration / sampling script.
parser = argparse.ArgumentParser()
# Positional: config file handed to magjoint.BallJoint below.
parser.add_argument(
    "config",
    help="path to magjoint config file, eg config/single_magnet.yaml",
)
parser.add_argument(
    "-g",
    action="store_true",
    help="generate magnetic field samples",
)
parser.add_argument(
    "-s",
    type=float,
    default=1.0,
    help="steps at which the magnetic field shall be sampled",
)
parser.add_argument(
    "-scale",
    type=float,
    default=1.0,
    help="scale the magnetic field in cloud visualization",
)
parser.add_argument(
    "-m",
    default='data/three_magnets.npz',
    help="model name to load, eg data/three_magnets.npz",
)
parser.add_argument(
    "-v",
    action="store_true",
    help="visualize only",
)
parser.add_argument(
    "-p",
    action="store_true",
    help="predict",
)
# One or more integer sensor ids; the default lists all 26 in this order.
parser.add_argument(
    "-select",
    nargs='+',
    type=int,
    default=[0, 1, 14, 2, 15, 3, 16, 4, 17, 5, 18, 6, 19, 7, 20,
             8, 21, 9, 22, 10, 23, 11, 24, 12, 25, 13],
    help="select which sensors",
)
args = parser.parse_args()

# Build the ball-joint model from the config named on the command line and
# echo the parsed options so each run records its own settings.
ball = magjoint.BallJoint(args.config)
print(args)

magnets = ball.gen_magnets()

# -v: only show the magnet arrangement, then terminate the script.
visualize_only = args.v
if visualize_only:
    ball.plotMagnets(magnets)
    sys.exit()

# Register this process as a ROS node.
rospy.init_node('magnetic_field_calibration', anonymous=True)

# Use the fully qualified message class: only the `std_msgs.msg` module is
# imported by name, so a bare `Float32` may be unbound here.
motor_target = rospy.Publisher('motor_target', std_msgs.msg.Float32, queue_size=1)

if args.g:  # record data
    # A rospy publisher needs time to connect to its subscribers; a message
    # published right after the Publisher is created is commonly dropped.
    rospy.sleep(1)
    # Command the motor to 0 (rospy builds the Float32 from the bare value),
    # then give it a moment before continuing.
    motor_target.publish(0)
    rospy.sleep(1)
import numpy as np
from scipy.interpolate import griddata

# Positional command line of the field-visualization part:
#   argv[1]  ball-joint config file
#   argv[2]  x_step -- polar sampling step in degrees (converted with pi/180 below)
#   argv[3]  y_step -- azimuthal sampling step in degrees
#   argv[4]  '1' to plot the magnet arrangement, anything else to skip it
#   argv[5]  scale factor
# Five positional arguments means len(sys.argv) >= 6. The previous `< 5`
# test let a four-argument call through, which then died with an
# IndexError on sys.argv[5] instead of printing the usage text.
if len(sys.argv) < 6:
    print(
        "\nUSAGE: ./magnetic_field_visualization.py ball_joint_config x_step y_step plot_magnet_arrangement scale, e.g. \n python3 magnetic_field_visualization.py two_magnets.yaml 10 10 1 0.1\n"
    )
    sys.exit(1)  # invocation error -> non-zero exit status

balljoint_config = sys.argv[1]
x_step = int(sys.argv[2])
y_step = int(sys.argv[3])
plot_magnet_arrangement = sys.argv[4] == '1'
scale = float(sys.argv[5])

# Both steps are used as np.arange step sizes when sampling the sphere:
# zero makes that call fail, and a negative value silently yields no samples.
if x_step <= 0 or y_step <= 0:
    print("\nx_step and y_step must be positive integers (degrees)\n")
    sys.exit(1)

# Model for the visualization part, built from the positional config path,
# together with its generated magnets.
ball = magjoint.BallJoint(balljoint_config)
magnets = ball.gen_magnets()

# Optional preview of the magnet arrangement (argv[4] == '1').
if plot_magnet_arrangement:
    ball.plotMagnets(magnets)

# Parallel per-sample accumulators, appended to by the sampling loop below.
grid_positions = []
positions = []
pos_offsets = []
angles = []
angle_offsets = []
for i in np.arange(-math.pi + math.pi / 180 * x_step,
                   math.pi - math.pi / 180 * x_step, math.pi / 180 * x_step):
    for j in np.arange(-math.pi, math.pi, math.pi / 180 * y_step):
        grid_positions.append([i, j])
        positions.append([
            22 * math.sin(i) * math.cos(j), 22 * math.sin(i) * math.sin(j),
            22 * math.cos(i)
        ])
        pos_offsets.append([0, 0, 0])