def run(self):
    """Drive the optimiser: seed the bests, iterate, and write a results CSV.

    Every particle is first scored at its current position. Scoring means
    ``calc_violation`` then ``calc_aep`` then ``calc_fitness``. That score
    becomes the particle's personal best and, if it beats
    ``self.g_best_fitness`` (or none exists yet), the global best.
    ``self.iterate()`` is then called while ``self.iter_count`` (starting at 1)
    is below ``self.max_iterations``. After each step the state is logged and
    all particle positions plus the global best position are plotted. Finally
    the data is logged once more, and ``self.log_file`` is called with a path
    under ``results/`` that encodes the best AEP, best fitness, particle count
    and iteration limit.
    """
    best_plotter = Plotter(1)
    particle_plotter = Plotter(2)

    def plot_state():
        # one plot call per particle, then one for the current global best
        for member in self.particles:
            particle_plotter.plot(member.pos)
        best_plotter.plot(self.g_best_pos)

    # initial evaluation: the starting position is each particle's first personal best
    for member in self.particles:
        violation = member.calc_violation(member.pos)
        aep = self.calc_aep(member.pos)
        member.fitness = self.calc_fitness(aep, violation)
        beats_global = self.g_best_fitness is None or member.fitness > self.g_best_fitness
        if beats_global:
            self.g_best_fitness = member.fitness
            self.g_best_pos = member.pos
            self.g_best_aep = aep
        member.best_pos = member.pos
        member.best_aep = aep
        member.best_fitness = member.fitness

    self.log_data()
    plot_state()

    self.iter_count = 1
    while self.iter_count < self.max_iterations:
        self.iterate()
        self.log_data()
        self.iter_count += 1
        plot_state()

    self.log_data()
    name_parts = [
        round(self.g_best_aep, 4),
        round(self.g_best_fitness, 4),
        self.num_particles,
        self.max_iterations,
    ]
    self.log_file('results/' + "_".join(str(part) for part in name_parts) + '.csv')
def plot_ekf_data(output_dir, timestamps, gt_poses, gt_vels, est_poses, est_states, g_const=9.80665):
    """Plot EKF estimates against ground truth with ``Plotter(output_dir)``.

    Produces the XY/XZ/YZ trajectory plots. It then produces per-axis time
    plots of position, velocity, rotation (``log_SO3``, unwrapped) and
    gravity, each comparing ground truth and estimate. Finally it produces
    per-axis time plots of the estimated gyro and accel biases.

    Args:
        output_dir: passed to ``Plotter``.
        timestamps: one time value per sample, plotted as "t [s]".
        gt_poses: ground-truth 4x4 homogeneous poses, one per sample.
        gt_vels: ground-truth velocities, three components per sample.
        est_poses: estimated 4x4 poses. Each is inverted before its translation
            and rotation are read (presumably stored in the opposite direction
            to ``gt_poses`` -- confirm with the caller).
        est_states: filter states, decoded with ``IMUKalmanFilter.decode_state``;
            indexed in step with ``est_poses``.
        g_const: gravity magnitude. Ground-truth gravity for a pose with
            rotation C is ``C^T . [0, 0, g_const]``.
    """
    # convert all to np
    timestamps = np.array([np.array(item) for item in timestamps], dtype=np.float64)
    gt_poses = np.array([np.array(item) for item in gt_poses], dtype=np.float64)
    gt_vels = np.array([np.array(item) for item in gt_vels], dtype=np.float64)
    est_poses = np.array([np.array(item) for item in est_poses], dtype=np.float64)
    est_states = np.array([np.array(item) for item in est_states], dtype=np.float64)

    def as_rows(values):
        # One row per sample. Unlike np.squeeze, this stays 2-D when there is
        # only a single sample, so the [:, c] indexing below keeps working.
        return np.array(values).reshape(len(values), -1)

    est_positions, est_rots, est_vels = [], [], []
    est_gravities, est_bw, est_ba = [], [], []
    for i, est_pose in enumerate(est_poses):
        pose = np.linalg.inv(est_pose)
        g, _, _, v, bw, ba = IMUKalmanFilter.decode_state(torch.tensor(est_states[i]))
        est_positions.append(pose[0:3, 3])
        est_rots.append(log_SO3(pose[0:3, 0:3]))
        est_vels.append(np.array(v))
        est_gravities.append(np.array(g))
        est_bw.append(np.array(bw))
        est_ba.append(np.array(ba))
    est_positions = as_rows(est_positions)
    est_vels = as_rows(est_vels)
    est_rots = as_rows(est_rots)
    est_gravities = as_rows(est_gravities)
    est_bw = as_rows(est_bw)
    est_ba = as_rows(est_ba)

    gt_rots = np.array([log_SO3(p[:3, :3]) for p in gt_poses])
    gravity_world = [0, 0, g_const]
    gt_gravities = np.array([p[0:3, 0:3].transpose().dot(gravity_world) for p in gt_poses])

    # unwrap every rotation component over time so the angle plots have no 2*pi jumps
    est_rots = np.unwrap(est_rots, axis=0)
    gt_rots = np.unwrap(gt_rots, axis=0)

    plotter = Plotter(output_dir)
    axis_names = ("X", "Y", "Z")

    # trajectory projections onto the three coordinate planes
    planes = (
        (0, 1, "x [m]", "y [m]", "XY Plot"),
        (0, 2, "x [m]", "z [m]", "XZ Plot"),
        (1, 2, "y [m]", "z [m]", "YZ Plot"),
    )
    for i0, i1, xlabel, ylabel, title in planes:
        plotter.plot((
            [gt_poses[:, i0, 3], gt_poses[:, i1, 3]],
            [est_positions[:, i0], est_positions[:, i1]],
        ), xlabel, ylabel, title, labels=["gt_poses", "est_pose"], equal_axes=True)

    # ground truth vs estimate over time, one figure per quantity and axis
    compared = (
        (gt_poses[:, :3, 3], est_positions, "p [m]", "Pos"),
        (gt_vels, est_vels, "v [m/s]", "Vel"),
        (gt_rots, est_rots, "rot [rad]", "Rot"),
        (gt_gravities, est_gravities, "accel [m/s^2]", "Gravity"),
    )
    for gt_series, est_series, ylabel, quantity in compared:
        for c, axis_name in enumerate(axis_names):
            plotter.plot((
                [timestamps, gt_series[:, c]],
                [timestamps, est_series[:, c]],
            ), "t [s]", ylabel, "%s %s" % (quantity, axis_name), labels=["gt", "est"])

    # biases only exist in the estimate
    biases = (
        (est_bw, "w [rad/s]", "Gyro Bias"),
        (est_ba, "a [m/s^2]", "Accel Bias"),
    )
    for est_series, ylabel, quantity in biases:
        for c, axis_name in enumerate(axis_names):
            plotter.plot(([timestamps, est_series[:, c]],), "t [s]", ylabel, "%s %s" % (quantity, axis_name))
def preprocess_kitti_raw(raw_seq_dir, output_dir, cam_subset_range, plot_figures=True):
    """Convert one raw KITTI sequence into per-camera-frame data.

    Loads the OXTS poses/records and the ``image_02`` timestamps from
    ``raw_seq_dir``. Keeps the camera images
    ``cam_subset_range[0]..cam_subset_range[1]`` (both inclusive). Splits the
    IMU stream into one segment per pair of consecutive images, with states
    interpolated at both camera timestamps. Saves the frames with
    ``SequenceData.save_as_pd`` into ``output_dir``.

    Args:
        raw_seq_dir: raw sequence directory. ``../T_velo_imu.txt`` and
            ``../T_cam_velo.txt`` relative to it hold the extrinsics.
        output_dir: destination for the logger, the saved data and the figures.
        cam_subset_range: ``(first, last)`` camera image indices, inclusive.
            At least two images are required.
        plot_figures: also write sanity-check figures with ``Plotter(output_dir)``.

    Raises:
        AssertionError: on inconsistent input lengths, an invalid camera range,
            or camera timestamps outside the IMU time span.
    """
    logger.initialize(working_dir=output_dir, use_tensorboard=False)
    logger.print("================ PREPROCESS KITTI RAW ================")
    logger.print("Preprocessing %s" % raw_seq_dir)
    logger.print("Output to: %s" % output_dir)
    logger.print("Camera images: %d => %d" % (cam_subset_range[0], cam_subset_range[1]))

    oxts_dir = os.path.join(raw_seq_dir, "oxts")
    image_dir = os.path.join(raw_seq_dir, "image_02")
    # each row of poses.txt is reshaped to 3x4 and completed with the homogeneous row
    gps_poses = np.loadtxt(os.path.join(oxts_dir, "poses.txt"))
    gps_poses = np.array([np.vstack([np.reshape(p, [3, 4]), [0, 0, 0, 1]]) for p in gps_poses])
    T_velo_imu = np.loadtxt(os.path.join(raw_seq_dir, "../T_velo_imu.txt"))
    T_cam_velo = np.loadtxt(os.path.join(raw_seq_dir, '../T_cam_velo.txt'))
    T_cam_imu = T_cam_velo.dot(T_velo_imu)

    # imu timestamps
    imu_timestamps = read_timestamps(os.path.join(oxts_dir, "timestamps.txt"))
    assert (len(imu_timestamps) == len(gps_poses))

    # load image data
    cam_timestamps = read_timestamps(os.path.join(image_dir, "timestamps.txt"))
    image_paths = [os.path.join(image_dir, "data", p) for p in sorted(os.listdir(os.path.join(image_dir, "data")))]
    assert (len(cam_timestamps) == len(image_paths))
    # at least two images are needed: every frame except the last is paired with its successor
    assert (0 <= cam_subset_range[0] < cam_subset_range[1] < len(image_paths))

    # the selected camera images must lie within the IMU time span
    assert (cam_timestamps[cam_subset_range[0]] >= imu_timestamps[0])
    assert (cam_timestamps[cam_subset_range[1]] <= imu_timestamps[-1])

    # take the subset of camera images in the range we are interested in
    cam_slice = slice(cam_subset_range[0], cam_subset_range[1] + 1)
    image_paths = image_paths[cam_slice]
    cam_timestamps = cam_timestamps[cam_slice]

    # convert to local time reference in seconds (relative to the first IMU timestamp)
    cam_timestamps = (cam_timestamps - imu_timestamps[0]) / np.timedelta64(1, 's')
    imu_timestamps = (imu_timestamps - imu_timestamps[0]) / np.timedelta64(1, 's')

    # take the subset of imu data that corresponds to the camera images
    idx_imu_data_start = find_timestamps_in_between(cam_timestamps[0], imu_timestamps)[0]
    idx_imu_data_end = find_timestamps_in_between(cam_timestamps[-1], imu_timestamps)[1]
    imu_timestamps = imu_timestamps[idx_imu_data_start:idx_imu_data_end + 1]
    gps_poses = gps_poses[idx_imu_data_start:idx_imu_data_end + 1]

    imu_data = _load_oxts_records(oxts_dir, idx_imu_data_start, idx_imu_data_end)
    assert (len(imu_data) == len(gps_poses))

    imu_timestamps, imu_data, gps_poses = remove_negative_timesteps(imu_timestamps, imu_data, gps_poses)

    data_frames = _build_data_frames(cam_timestamps, image_paths, imu_timestamps, imu_data, gps_poses)

    gravity = np.array([0, 0, 9.808679801065017])
    df = SequenceData.save_as_pd(data_frames, gravity, np.zeros(3), T_cam_imu, output_dir)

    if plot_figures:
        _plot_kitti_sanity_figures(output_dir, df.to_dict("list"), imu_timestamps, imu_data, gps_poses, gravity)
    logger.print("All done!")


def _load_oxts_records(oxts_dir, idx_start, idx_end):
    """Load the OXTS text records with sorted file indices idx_start..idx_end (inclusive) into one array."""
    file_names = sorted(os.listdir(os.path.join(oxts_dir, "data")))
    n_records = idx_end + 1 - idx_start
    records = []
    start_time = time.time()
    for count, i in enumerate(range(idx_start, idx_end + 1), start=1):
        print("Loading IMU data files %d/%d (%.2f%%)" % (count, n_records, 100 * count / n_records), end='\r')
        records.append(np.loadtxt(os.path.join(oxts_dir, "data", file_names[i])))
    records = np.array(records)
    logger.print("\nLoading IMU data took %.2fs" % (time.time() - start_time))
    return records


def _interpolate_imu_at(t, idx, imu_timestamps, imu_data, gps_poses):
    """Interpolate the IMU record and GPS pose at time t, between samples idx - 1 and idx."""
    # idx > 0 prevents idx - 1 from silently wrapping around to the last sample
    assert (idx > 0 and imu_timestamps[idx - 1] <= t <= imu_timestamps[idx])
    t_i = imu_timestamps[idx - 1]
    t_j = imu_timestamps[idx]
    alpha = (t - t_i) / (t_j - t_i)
    return interpolate(imu_data[idx - 1], imu_data[idx], gps_poses[idx - 1], gps_poses[idx], alpha)


def _build_data_frames(cam_timestamps, image_paths, imu_timestamps, imu_data, gps_poses):
    """Split the IMU stream into one SequenceData.Frame per camera image.

    Frame k holds the state interpolated at t_k, every raw IMU sample with
    t_k <= t < t_k+1, and the state interpolated at t_k+1. The last image has
    no successor and so gets a frame with empty IMU data.
    """
    data_frames = []
    start_time = time.time()
    n_cams = len(cam_timestamps)
    idx_start = 0  # first IMU sample with timestamp >= t_k
    idx_end = 0  # first IMU sample with timestamp >= t_k+1
    for k in range(n_cams - 1):
        print("Processing IMU data files %d/%d (%.2f%%)" % (k + 1, n_cams, 100 * (k + 1) / n_cams), end='\r')
        t_k = cam_timestamps[k]
        t_kp1 = cam_timestamps[k + 1]

        # both indices only move forward, so the scan is linear over the whole sequence
        while imu_timestamps[idx_start] < t_k:
            idx_start += 1
        T_i_vk, v_vk, w_vk, a_vk = _interpolate_imu_at(t_k, idx_start, imu_timestamps, imu_data, gps_poses)

        while imu_timestamps[idx_end] < t_kp1:
            idx_end += 1
        T_i_vkp1, v_vkp1, w_vkp1, a_vkp1 = _interpolate_imu_at(t_kp1, idx_end, imu_timestamps, imu_data, gps_poses)

        # raw samples inside [t_k, t_k+1) are idx_start .. idx_end - 1; idx_end - 1 (the last
        # sample before t_k+1) must be included, hence the slice ends at idx_end
        inner = slice(idx_start, idx_end)
        frame_k = SequenceData.Frame(
            image_paths[k], t_k, T_i_vk, v_vk,
            np.concatenate([[T_i_vk], gps_poses[inner], [T_i_vkp1]]),
            np.concatenate([[t_k], imu_timestamps[inner], [t_kp1]]),
            np.concatenate([[a_vk], imu_data[inner, ax:az + 1], [a_vkp1]]),
            np.concatenate([[w_vk], imu_data[inner, wx:wz + 1], [w_vkp1]]))
        data_frames.append(frame_k)

        # sanity checks: each segment starts at its own camera state and continues the previous one
        assert (np.allclose(frame_k.timestamp, frame_k.imu_timestamps[0], atol=1e-13))
        assert (np.allclose(frame_k.T_i_vk, frame_k.imu_poses[0], atol=1e-13))
        if len(data_frames) > 1:
            prev = data_frames[-2]
            assert (np.allclose(frame_k.timestamp, prev.imu_timestamps[-1], atol=1e-13))
            assert (np.allclose(frame_k.T_i_vk, prev.imu_poses[-1], atol=1e-13))
            assert (np.allclose(frame_k.accel_measurements[0], prev.accel_measurements[-1], atol=1e-13))
            assert (np.allclose(frame_k.gyro_measurements[0], prev.gyro_measurements[-1], atol=1e-13))

    # add the last frame without any IMU data
    data_frames.append(SequenceData.Frame(image_paths[-1], t_kp1, T_i_vkp1, v_vkp1,
                                          np.zeros([0, 4, 4]), np.zeros([0]),
                                          np.zeros([0, 3]), np.zeros([0, 3])))
    logger.print("\nProcessing data took %.2fs" % (time.time() - start_time))
    return data_frames


def _plot_kitti_sanity_figures(output_dir, data, imu_timestamps, imu_data, gps_poses, gravity):
    """Write figures cross-checking the saved frames against integrated IMU data and the raw OXTS stream.

    ``data`` is the saved DataFrame as ``to_dict("list")``. ``imu_timestamps``,
    ``imu_data`` and ``gps_poses`` are the stream the frames were built from.
    """
    start_time = time.time()
    plotter = Plotter(output_dir)
    axis_names = ("X", "Y", "Z")
    p_poses = np.array(data["T_i_vk"])
    p_timestamps = np.array(data["timestamp"])
    p_velocities = np.array(data["v_vk_i_vk"])

    # the last entry of every frame's IMU segment repeats the next frame's first entry
    p_imu_timestamps = np.concatenate([d[:-1] for d in data['imu_timestamps']])
    p_gyro_measurements = np.concatenate([d[:-1] for d in data['gyro_measurements']])
    p_accel_measurements = np.concatenate([d[:-1] for d in data["accel_measurements"]])
    p_imu_poses = np.concatenate([d[:-1, :, :] for d in data["imu_poses"]])
    assert (len(p_imu_timestamps) == len(p_gyro_measurements))
    assert (len(p_imu_timestamps) == len(p_accel_measurements))
    assert (len(p_imu_timestamps) == len(p_imu_poses))

    # integrate accel (to velocity, then position) to compare against the stored velocities/poses
    p_accel_int = [p_velocities[0, :]]
    p_accel_int_int = [p_poses[0, :3, 3]]
    for i in range(len(p_imu_timestamps) - 1):
        dt = p_imu_timestamps[i + 1] - p_imu_timestamps[i]
        C_i_vk = p_imu_poses[i, :3, :3]
        C_vkp1_vk = p_imu_poses[i + 1, :3, :3].transpose().dot(C_i_vk)
        v_vkp1_i_vk = p_accel_int[-1] + dt * (p_accel_measurements[i] - C_i_vk.transpose().dot(gravity))
        p_accel_int.append(C_vkp1_vk.dot(v_vkp1_i_vk))
        # rotate the velocity *before* C_vkp1_vk is applied with C_i_vk
        # (equivalently C_i_vkp1 applied to p_accel_int[-1]); the frames must match
        p_accel_int_int.append(p_accel_int_int[-1] + C_i_vk.dot(v_vkp1_i_vk) * dt)
    p_accel_int = np.array(p_accel_int)
    p_accel_int_int = np.array(p_accel_int_int)

    # poses from integrating the stored velocities
    p_vel_int_poses = [p_poses[0, :3, 3]]
    for i in range(len(p_velocities) - 1):
        dt = p_timestamps[i + 1] - p_timestamps[i]
        p_vel_int_poses.append(p_vel_int_poses[-1] + p_poses[i, :3, :3].dot(p_velocities[i]) * dt)
    p_vel_int_poses = np.array(p_vel_int_poses)

    planes = (
        (0, 1, "x [m]", "Y [m]", "XY"),
        (0, 2, "X [m]", "Z [m]", "XZ"),
        (1, 2, "Y [m]", "Z [m]", "YZ"),
    )
    for i0, i1, xlabel, ylabel, name in planes:
        plotter.plot(([p_poses[:, i0, 3], p_poses[:, i1, 3]],
                      [p_vel_int_poses[:, i0], p_vel_int_poses[:, i1]],
                      [p_accel_int_int[:, i0], p_accel_int_int[:, i1]],),
                     xlabel, ylabel, name + " Plot",
                     labels=["dat_poses", "dat_vel_int", "dat_acc_int^2"], equal_axes=True)
    for c, axis_name in enumerate(axis_names):
        plotter.plot(([p_timestamps, p_poses[:, c, 3]],
                      [p_timestamps, p_vel_int_poses[:, c]],
                      [p_imu_timestamps, p_accel_int_int[:, c]],),
                     "t [s]", axis_name + " [m]", axis_name + " Plot From Zero",
                     labels=["dat_poses", "dat_vel_int", "dat_acc_int^2"])

    # trajectory expressed relative to the first frame
    T_first_inv = np.linalg.inv(p_poses[0])
    p_poses_from_I = np.array([T_first_inv.dot(p) for p in p_poses])
    for i0, i1, xlabel, ylabel, name in planes:
        plotter.plot(([p_poses_from_I[:, i0, 3], p_poses_from_I[:, i1, 3]],),
                     xlabel, ylabel, name + " Plot From Identity", equal_axes=True)

    # plot p_velocities (own title, so it does not overwrite the "YZ Plot" trajectory figure)
    plotter.plot(([p_timestamps, p_velocities[:, 0]],
                  [p_timestamps, p_velocities[:, 1]],
                  [p_timestamps, p_velocities[:, 2]]),
                 "t [s]", "v [m/s]", "Velocity Plot", labels=["dat_vx", "dat_vy", "dat_vz"])

    # make sure the interpolated acceleration and gyroscope measurements match the raw data
    for c, (axis_name, col) in enumerate(zip(axis_names, (wx, wy, wz))):
        plotter.plot(([p_imu_timestamps, p_gyro_measurements[:, c]], [imu_timestamps, imu_data[:, col]],),
                     "t [s]", "w [rad/s]", "Rot Vel %s Verification" % axis_name)
    for c, (axis_name, col) in enumerate(zip(axis_names, (ax, ay, az))):
        plotter.plot(([p_imu_timestamps, p_accel_measurements[:, c]], [imu_timestamps, imu_data[:, col]],),
                     "t [s]", "a [m/s^2]", "Accel %s Verification" % axis_name)

    # integrate gyroscope to compare against rotation
    p_gyro_int = [p_poses[0, :3, :3]]
    for i in range(len(p_imu_timestamps) - 1):
        dt = p_imu_timestamps[i + 1] - p_imu_timestamps[i]
        p_gyro_int.append(p_gyro_int[-1].dot(exp_SO3(dt * p_gyro_measurements[i])))
    p_gyro_int = np.array([log_SO3(o) for o in p_gyro_int])
    p_orientation = np.array([log_SO3(p[:3, :3]) for p in p_poses])
    for c, axis_name in enumerate(axis_names):
        plotter.plot(([p_imu_timestamps, np.unwrap(p_gyro_int[:, c])],
                      [p_timestamps, np.unwrap(p_orientation[:, c])],),
                     "t [s]", "rot [rad]", "Theta %s Cmp Plot" % axis_name, labels=["gyro_int", "dat_pose"])

    # velocity from consecutive GPS poses, expressed in the earlier pose's frame
    vel_from_gps_rel_poses = []
    for k in range(len(gps_poses) - 1):
        dt = imu_timestamps[k + 1] - imu_timestamps[k]
        T_vk_vkp1 = np.linalg.inv(gps_poses[k]).dot(gps_poses[k + 1])
        vel_from_gps_rel_poses.append(T_vk_vkp1[0:3, 3] / dt)
    vel_from_gps_rel_poses = np.array(vel_from_gps_rel_poses)
    for c, axis_name in enumerate(axis_names):
        plotter.plot(([imu_timestamps[1:], vel_from_gps_rel_poses[:, c]],
                      [p_timestamps, p_velocities[:, c]],
                      [p_imu_timestamps, p_accel_int[:, c]],),
                     "t [s]", "v [m/s]", "Velocity %s Cmp Plot" % axis_name,
                     labels=["gps_rel", "dat_vel", "dat_accel_int"])

    logger.print("Generating figures took %.2fs" % (time.time() - start_time))
def main():
    """Train the AnoVAEGAN model adversarially against a Discriminator.

    Uses the ``MNISTDataset`` training split with horizontal-flip and rotation
    augmentation. Each batch runs one model (reconstruction + prior +
    adversarial loss) update, then one discriminator update on real images vs.
    detached reconstructions. Schedulers whose ``update_on_step`` is false are
    stepped once per epoch instead. Smoothed losses and learning rates go to
    Visdom when ``config.VISDOM`` is set. Debug batches are saved when
    ``config.DEBUG.USE`` is set. At the end the config and model
    ``state_dict`` are saved to ``<OUTPUT_DIR>/<EXP_NAME>_checkpoint.pth``.
    """
    args = parse_args()  # kept for its side effects; the return value is not used below
    out_dir = config.OUTPUT_DIR
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    print(config)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_cfg = config.MODEL
    model = AnoVAEGAN(model_cfg.IN_CHANNELS, model_cfg.N_LAYERS, model_cfg.N_FEATURES)
    discriminator = Discriminator(model_cfg.IN_CHANNELS, config.DATASET.INPUT_SIZE,
                                  model_cfg.N_LAYERS, model_cfg.N_FEATURES)
    for net in (model, discriminator):
        net.to(device)
    for net in (model, discriminator):
        net.train()
    print(model)

    augmentations = [("HorizontalFlip", None), ("Rotation", {"degrees": 30})]
    dataset = MNISTDataset(data_dir=config.DATASET.DIR, split='train',
                           input_size=config.DATASET.INPUT_SIZE, transforms=augmentations)
    train_loader = DataLoader(dataset, batch_size=config.OPTIM.BATCH_SIZE, shuffle=True)
    steps_per_epoch = len(train_loader)
    model_opt, model_scheduler = build_opt(config, model, steps_per_epoch)
    disc_opt, disc_scheduler = build_opt(config, discriminator, steps_per_epoch, discriminator=True)

    loss_cfg = config.LOSS
    rec_loss = build_loss(loss_cfg.REC_LOSS)
    prior_loss = build_loss(loss_cfg.PRIOR_LOSS)
    adv_loss = build_loss(loss_cfg.ADV_LOSS)
    rec_key = rec_loss_map[loss_cfg.REC_LOSS]

    loss_meter, rec_loss_meter, prior_loss_meter, adv_loss_meter = (ExpAvgMeter(0.98) for _ in range(4))

    if config.VISDOM:
        plotter = Plotter(log_to_filename=os.path.join(out_dir, "logs.viz"))

    n_epochs = config.OPTIM.EPOCH
    step = 0
    for epoch in range(n_epochs):
        pbar = tqdm(train_loader)
        for img in pbar:
            step += 1
            img = img.to(device)
            batch_size = img.shape[0]
            valid = torch.ones((batch_size, 1), dtype=torch.float, device=device)
            fake = torch.zeros((batch_size, 1), dtype=torch.float, device=device)

            # model update: reconstruction + prior + fooling the discriminator
            out = model(img)
            rec_l = rec_loss(out[rec_key], img)
            prior_l = prior_loss(out['mu'], out['logvar'])
            adv_l = adv_loss(discriminator(out['rec']), valid)
            loss = (loss_cfg.REC_LOSS_COEFF * rec_l
                    + loss_cfg.PRIOR_LOSS_COEFF * prior_l
                    + loss_cfg.ADV_LOSS_COEFF * adv_l)
            model_opt.zero_grad()
            loss.backward()
            model_opt.step()
            if model_scheduler.update_on_step:
                model_scheduler.step()

            # discriminator update; the reconstruction is detached from the model graph
            real_loss = adv_loss(discriminator(img), valid)
            fake_loss = adv_loss(discriminator(out['rec'].detach()), fake)
            disc_loss = 0.5 * (real_loss + fake_loss)
            disc_opt.zero_grad()
            disc_loss.backward()
            disc_opt.step()
            if disc_scheduler.update_on_step:
                disc_scheduler.step()

            for meter, value in ((loss_meter, loss), (rec_loss_meter, rec_l),
                                 (prior_loss_meter, prior_l), (adv_loss_meter, adv_l)):
                meter.update(value.item())

            pbar.set_description('Train Epoch : {0}/{1} Loss : {2:.4f} '.format(
                epoch + 1, n_epochs, loss_meter.value))

            if config.VISDOM and step % config.PLOT_EVERY == 0:
                loss_curves = ((loss_meter.value, "Loss"), (rec_loss_meter.value, "Rec loss"),
                               (prior_loss_meter.value, "Prior loss"), (adv_loss_meter.value, "Adv loss"))
                for value, legend in loss_curves:
                    plotter.plot("Loss", step, value, legend, "Step", "Value")
                plotter.plot("LR", step, model_opt.param_groups[0]['lr'], "Model LR", "Step", "Value")
                plotter.plot("LR", step, disc_opt.param_groups[0]['lr'], "Discr LR", "Step", "Value")

            if config.DEBUG.USE and step % config.DEBUG.DEBUG_EVERY == 0:
                save_batch_output(img, out['rec'], config.DEBUG.DETECT_THRESH, config.DEBUG.SAVE_SIZE,
                                  out_dir, 'batch_{}'.format(step))

        # schedulers that do not step per batch advance once per epoch
        if not model_scheduler.update_on_step:
            model_scheduler.step()
        if not disc_scheduler.update_on_step:
            disc_scheduler.step()

    save_path = os.path.join(out_dir, config.EXP_NAME + "_checkpoint.pth")
    torch.save({'cfg': config, 'params': model.state_dict()}, save_path)
ax.axis((x0 - plot_margin, x1 + plot_margin, y0 - plot_margin, y1 + plot_margin)) #K08 if seq == "K08": ax.axis((0, 1000, -600, 400)) plotter.plot(( [imu_only_poses[:, 0, 3], imu_only_poses[:, 1, 3]], [vision_only_poses[:, 0, 3], vision_only_poses[:, 1, 3]], [msf_fusion_poses[:, 0, 3], msf_fusion_poses[:, 1, 3]], [gt_poses[:, 0, 3], gt_poses[:, 1, 3]], [vanilla_poses[:, 0, 3], vanilla_poses[:, 1, 3]], ), "x [m]", "y [m]", None, labels=["IMU", "vision", "ORB+MSF", "ground truth", "proposed"], colors=["turquoise", "gold", "green", "blue", "red"], equal_axes=True, filename=seq + ".svg", callback=plot_callback) # plotter.plot(([gt_poses[:, 0, 3], gt_poses[:, 1, 3]], # [hybrid_poses[:, 0, 3], hybrid_poses[:, 1, 3]], # ), # "x [m]", "y [m]", "KITTI Sequence %s" % seq[1:], # labels=["gt", "hybrid"], # equal_axes=True, filename=seq+"_one")
# ax.axis((0,1000,-600,400)) plotter.plot(( [ vision_only_traj_aligned.positions_xyz[:, 0], vision_only_traj_aligned.positions_xyz[:, 1] ], [ vins_mono_traj_aligned.positions_xyz[:, 0], vins_mono_traj_aligned.positions_xyz[:, 1] ], [ gt_traj_synced_vanilla.positions_xyz[:, 0], gt_traj_synced_vanilla.positions_xyz[:, 1] ], [ vanilla_traj_aligned.positions_xyz[:, 0], vanilla_traj_aligned.positions_xyz[:, 1] ], ), "x [m]", "y [m]", None, labels=["vision", "vins_mono", "gt", "proposed"], equal_axes=True, filename=seq + ".svg", callback=plot_callback) # plotter.plot(([gt_poses[:, 0, 3], gt_poses[:, 1, 3]], # [hybrid_poses[:, 0, 3], hybrid_poses[:, 1, 3]],