def mask_and_concat_data_along_batch_dim(data, k):
    """Split planner data at time index k and batch the part that precedes it.

    A segment is kept when the running total of ``traj.k`` over
    ``data['trajectory']``, up to and including that segment, is <= k. The
    kept segments of each list in data are merged along the batch dimension
    and written back into data, so data is modified in place. The first
    segment that fails the test is returned separately, unbatched.

    Args:
        data: dict of per-segment lists stored under the keys
            'system_config', 'waypoint_config', 'trajectory',
            'spline_trajectory', 'planning_horizon', 'K_nkfd', 'k_nkf1'
            and 'img_nmkd'.
        k: time index used as the inclusive cutoff.

    Returns:
        A tuple (data, data_last, last_data_valid). data is the input dict
        with its lists replaced by the masked, batched values; the masked
        planning horizons are stored under the new key
        'planning_horizon_n1'. data_last holds the entries of one single
        segment. last_data_valid is True when that segment is the first one
        past the cutoff and False when every segment passed the test, in
        which case data_last just carries the final segment as a
        placeholder.
    """
    # Entry i is the running total of segment lengths through segment i.
    segment_end_times = np.cumsum([traj.k for traj in data['trajectory']])
    valid_mask = segment_end_times <= k

    # Find the first segment that extends past k, if there is one.
    overflow_idxs = np.where(np.logical_not(valid_mask))[0]
    last_data_valid = len(overflow_idxs) > 0
    if last_data_valid:
        last_data_idx = overflow_idxs[0]
    else:
        # Nothing overflowed: fall back to the final segment; the flag
        # above tells the caller not to trust it.
        last_data_idx = len(valid_mask) - 1

    # Collect the single "last" segment before data is overwritten below.
    data_last = {
        key: data[key][last_data_idx]
        for key in ('system_config', 'waypoint_config', 'trajectory',
                    'spline_trajectory')
    }
    data_last['planning_horizon_n1'] = [
        data['planning_horizon'][last_data_idx]
    ]
    for key in ('K_nkfd', 'k_nkf1', 'img_nmkd'):
        data_last[key] = data[key][last_data_idx]

    def _keep_valid(segments):
        """Return the entries of segments selected by valid_mask."""
        return np.array(segments)[valid_mask]

    # Batch the segments that lie inside the cutoff.
    data['system_config'] = SystemConfig.concat_across_batch_dim(
        _keep_valid(data['system_config']))
    data['waypoint_config'] = SystemConfig.concat_across_batch_dim(
        _keep_valid(data['waypoint_config']))
    data['trajectory'] = Trajectory.concat_across_batch_dim(
        _keep_valid(data['trajectory']))
    data['spline_trajectory'] = Trajectory.concat_across_batch_dim(
        _keep_valid(data['spline_trajectory']))

    # [:, None] appends a trailing axis to the masked horizons.
    data['planning_horizon_n1'] = _keep_valid(
        data['planning_horizon'])[:, None]

    # NOTE(review): masking the concatenated tensor along axis 0 presumes
    # each list entry contributes exactly one row -- confirm with callers.
    for tensor_key in ('K_nkfd', 'k_nkf1'):
        data[tensor_key] = tf.boolean_mask(
            tf.concat(data[tensor_key], axis=0), valid_mask)

    stacked_imgs = np.array(np.concatenate(data['img_nmkd'], axis=0))
    data['img_nmkd'] = stacked_imgs[valid_mask]

    return data, data_last, last_data_valid
def concat_data_across_binning_dim(self, data):
    """Merge every bin of data into a single bin.

    It is assumed that data is a dictionary in which each key maps to a
    list of tensors, Trajectory objects, or SystemConfig objects. Each
    processed list is concatenated and replaced by a list of length 1
    holding the merged result (i.e. only one bin remains). data is
    modified in place and also returned.

    Args:
        data: dict with the keys 'start_speeds', 'start_configs',
            'waypt_configs', 'spline_trajectories', 'horizons',
            'lqr_trajectories', 'K_nkfd' and 'k_nkf1'.

    Returns:
        The same dict, with each of the keys above mapped to a
        single-element list.
    """
    def _concat_tensors(tensors):
        """Concatenate a list of tensors along axis 0."""
        return tf.concat(tensors, axis=0)

    # (key, merge function) pairs, kept in the original processing order.
    mergers = (
        ('start_speeds', _concat_tensors),
        ('start_configs', SystemConfig.concat_across_batch_dim),
        ('waypt_configs', SystemConfig.concat_across_batch_dim),
        ('spline_trajectories', Trajectory.concat_across_batch_dim),
        ('horizons', _concat_tensors),
        ('lqr_trajectories', Trajectory.concat_across_batch_dim),
        ('K_nkfd', _concat_tensors),
        ('k_nkf1', _concat_tensors),
    )
    for key, merge in mergers:
        data[key] = [merge(data[key])]
    return data
def mask_and_concat_data_along_batch_dim(data, k):
    """Split control data at time index k and batch the part that precedes it.

    A segment is kept when the running total of ``shape[1]`` over the
    tensors in ``data['optimal_control_nk2']``, up to and including that
    segment, is <= k. The kept segments of each list in data are merged
    along the batch dimension and written back into data, so data is
    modified in place. The first segment that fails the test is returned
    separately, unbatched.

    Args:
        data: dict of per-segment lists stored under the keys
            'system_config', 'optimal_control_nk2' and 'img_nmkd'.
        k: time index used as the inclusive cutoff.

    Returns:
        A tuple (data, data_last, last_data_valid). data is the input dict
        with its lists replaced by the masked, batched values. data_last
        holds the entries of one single segment. last_data_valid is True
        when that segment is the first one past the cutoff and False when
        every segment passed the test, in which case data_last just
        carries the final segment as a placeholder.
    """
    # Entry i is the running total of segment lengths through segment i.
    segment_lengths = [
        u_nk2.shape[1].value for u_nk2 in data['optimal_control_nk2']
    ]
    segment_end_times = np.cumsum(segment_lengths)
    valid_mask = segment_end_times <= k

    # Find the first segment that extends past k, if there is one.
    overflow_idxs = np.where(np.logical_not(valid_mask))[0]
    last_data_valid = len(overflow_idxs) > 0
    if last_data_valid:
        last_data_idx = overflow_idxs[0]
    else:
        # Nothing overflowed: fall back to the final segment; the flag
        # above tells the caller not to trust it.
        last_data_idx = len(valid_mask) - 1

    # Collect the single "last" segment before data is overwritten below.
    data_last = {
        key: data[key][last_data_idx]
        for key in ('system_config', 'optimal_control_nk2', 'img_nmkd')
    }

    # Batch the segments that lie inside the cutoff.
    kept_configs = np.array(data['system_config'])[valid_mask]
    data['system_config'] = SystemConfig.concat_across_batch_dim(kept_configs)

    # NOTE(review): masking the concatenated tensor along axis 0 presumes
    # each list entry contributes exactly one row -- confirm with callers.
    data['optimal_control_nk2'] = tf.boolean_mask(
        tf.concat(data['optimal_control_nk2'], axis=0), valid_mask)

    stacked_imgs = np.array(np.concatenate(data['img_nmkd'], axis=0))
    data['img_nmkd'] = stacked_imgs[valid_mask]

    return data, data_last, last_data_valid