class Data():
    '''A helper class to load Ju Yao data.

    On construction it
      1. parses the organ CSV table (``hparam.csv_file``),
      2. builds per-organ 3D bool masks from the DICOM RTStruct when
         ``hparam.CT_RTStruct_dir`` exists,
      3. loads (or builds and caches) the sparse deposition matrix when
         ``'deposition_file'`` is in ``hparam``; otherwise ``hparam.max_ray_idx``
         must be given,
      4. parses ``hparam.valid_ray_file`` into per-beam bool / ray-index matrices.

    Arguments:
        hparam: dict-like config with attribute access (``'key' in hparam`` and
            ``hparam.key`` are both used).
    '''

    def __init__(self, hparam):
        self.hparam = hparam

        cprint('get info from csv file.', 'green')
        self._set_paramters_from_csv_table()

        if os.path.isdir(self.hparam.CT_RTStruct_dir):
            cprint('get organ 3D bool index from dicom RTStruct.', 'green')
            self._get_organ_3D_index()
        else:
            cprint('DICOM fold not found.', 'green')

        if 'deposition_file' in hparam:
            self.deposition = self.get_depositionMatrix()
            self.max_ray_idx = self.deposition.shape[1]
            cprint('deposition loaded.', 'green')
        else:
            # without a deposition file the number of valid rays must be given explicitly
            assert 'max_ray_idx' in hparam, 'hparam needs either deposition_file or max_ray_idx'
            self.max_ray_idx = hparam.max_ray_idx

        cprint('parsing valid ray data.', 'green')
        self.dict_rayBoolMat, self.dict_rayIdxMat, self.num_beams = self._read_rayIdx_mat()
        self._set_beamID_rayBeginNum_dict()

    def _set_beamID_rayBeginNum_dict(self):
        '''Set self.dict_beamID_ValidRayBeginNum {beam_id: [first valid ray idx, #valid rays]}.

        Rays are numbered consecutively over beams in dict_rayBoolMat order.
        Compares against self.max_ray_idx (== deposition.shape[1] when a
        deposition matrix was loaded), so this also works without a deposition.
        '''
        self.dict_beamID_ValidRayBeginNum = OrderedBunch()
        begin = 0
        for beam_id, mask in self.dict_rayBoolMat.items():
            num = int(mask.sum())
            self.dict_beamID_ValidRayBeginNum[beam_id] = [begin, num]
            begin += num
        assert begin == self.max_ray_idx, f'bixel number not match: num_bixel={begin}, max_ray_idx={self.max_ray_idx}'

    def get_depositionMatrix(self):
        '''Return the deposition matrix, loading the cached pickle or building it.

        The cache lives at ``<deposition_pickle_file_path>/deposition.pickle``.
        NOTE: the cache is read with unpickle_object; only point this at trusted files.

        Raises:
            AssertionError: if the row count differs from the summed CSV point numbers.
        '''
        pickle_dir = self.hparam.deposition_pickle_file_path
        os.makedirs(pickle_dir, exist_ok=True)
        fn_depos = os.path.join(pickle_dir, 'deposition.pickle')
        if os.path.isfile(fn_depos):
            cprint('loading deposition data.', 'green')
            D = unpickle_object(fn_depos)
        else:
            cprint('building deposition data from Deposition_Index.txt', 'green')
            D = self._read_deposition_file_and_save_to_pickle_sparse()  # use sparse matrix to store D

        # check shape: one row per dose-grid point of every organ listed in the csv
        ptsNum = sum(v['Points Number'] for v in self.organ_info.values())
        assert ptsNum == D.shape[0], f'shape not match: ptsNum={ptsNum}, deposition_matrix shape={D.shape}'
        print(f'deposition_matrix shape={D.shape}')
        return D

    def _get_uniqueRays_doseGridNum(self):
        '''Scan the deposition file for its ray indices and dose-grid point count.

        Organs with zero points are skipped; a repeated organ (same point number)
        is treated as redundant and its points are not counted.

        Return:
            ray_list: sorted list of unique ray indices found in the file
            doseGridNum: number of non-redundant dose-grid points

        Raises:
            ValueError: if the organs in the deposition file do not match the csv
                organs in number or order.
        '''
        print('check ray_index and organs in Deposition_Index.txt')
        ray_list, organName_Dep, doseGridNum = [], [], 0
        is_redundancy = False
        with open(self.hparam.deposition_file, 'r') as f:
            for line in f:
                if 'pts_num' in line:  # new organ
                    pts_num = int(line.split(':')[-1])
                    if pts_num == 0:
                        cprint(f'skip organ with zero points, points_num:{pts_num} in Deposition_Index.txt', 'yellow')
                        continue
                    organ_name = self.get_organName_from_pointNum(pts_num)
                    if organ_name in organName_Dep:
                        is_redundancy = True
                        cprint(f'skip duplicate points_num:{pts_num} and organ_name:{organ_name} in deposition.txt', 'yellow')
                    else:
                        is_redundancy = False
                        organName_Dep.append(organ_name)
                        cprint(f'find points_num:{pts_num} and organ_name:{organ_name} in deposition.txt', 'green')
                elif 'Indx:' in line:
                    ray_list.append(int(line.split(' ')[0].split(':')[-1]))
                elif not is_redundancy and ' ]:' in line:
                    doseGridNum += 1

        # ensure Deposition_Index.txt has all organs in csv, in the same order
        csv_organs = list(self.organ_info.keys())
        if len(csv_organs) != len(organName_Dep):
            raise ValueError(f'organ numbers in Dep and CSV not match: csv={csv_organs}, dep={organName_Dep}')
        for organName_CSV, organName_D in zip(csv_organs, organName_Dep):
            if organName_CSV != organName_D:
                raise ValueError(f'organ order in Dep and CSV not match: csv={csv_organs}, dep={organName_Dep}')

        # unique rays; sorted() because set iteration order is not guaranteed
        ray_list = sorted(set(ray_list))

        # ensure ray_list == [0, 1, 2, 3, ...]
        if self.hparam.is_check_ray_idx_order:
            for idx, ray_idx in enumerate(ray_list):
                assert idx == ray_idx, ray_idx
        else:
            cprint('ray idx order is NOT checked. Some ray idx (should in integer order) may not present in Deposition_Index.txt', 'red')
        return ray_list, doseGridNum

    def _read_deposition_file_and_save_to_pickle_sparse(self):
        '''Build the sparse deposition matrix from the deposition file and pickle it.

        Using a sparse (COO) matrix saves memory. Rows are non-redundant
        dose-grid points in file order; columns are ray indices.

        Return:
            D: scipy coo_matrix of shape (#points, max ray idx + 1)
        '''
        # get shape info to build the deposition matrix
        ray_list, DoseGridNum = self._get_uniqueRays_doseGridNum()

        print('building depostion matrix')
        row_idxs, col_idxs, values = [], [], []
        organ_order, point_idx = [], -1  # organ_order should be consistent with the deposition.txt
        is_redundancy = False
        with open(self.hparam.deposition_file, 'r') as f:
            for line in f:
                if 'pts_num' in line:  # new organ
                    pts_num = int(line.split(':')[-1])
                    if pts_num != 0:
                        organ_name = self.get_organName_from_pointNum(pts_num)
                        if organ_name in organ_order:
                            is_redundancy = True
                        else:
                            is_redundancy = False
                            organ_order.append(organ_name)
                elif not is_redundancy and ' ]:' in line:  # new organ point
                    point_idx += 1
                elif not is_redundancy and 'Indx' in line:  # new ray/bixel
                    ray_idx = int(line.split(' ')[0].split(':')[-1])
                    value = float(line.split('Pt_dose:')[-1].split('(')[0])
                    row_idxs.append(point_idx)
                    col_idxs.append(ray_idx)
                    values.append(value)

        D = coo_matrix((values, (row_idxs, col_idxs)), shape=(DoseGridNum, max(ray_list) + 1))
        cprint(f'sparse depostion matrix shape: {D.shape}', 'green')

        pickle_object(os.path.join(self.hparam.deposition_pickle_file_path, 'deposition.pickle'), D)
        return D

    @staticmethod
    def _drop_duplicated_columns(df):
        '''Drop columns whose name differs from an earlier column only by a pandas ".N" suffix.

        pandas renames duplicated csv headers to "name.1", "name.2", ...; this
        keeps only the first occurrence. regex=True is passed explicitly because
        pandas >= 2.0 defaults str.replace to literal matching.
        '''
        base_names = df.columns.str.replace(r'(\.\d+)$', '', regex=True)
        return df.loc[:, ~base_names.duplicated()]

    def _set_paramters_from_csv_table(self):
        '''Parse hparam.csv_file into self.csv_loss_table and self.organ_info.

        self.csv_loss_table keeps the duplicated organ columns (used by the loss);
        self.organ_info = {organ_name: {'Grid Size': mm, 'Points Number': int}}
        (Grid Size is multiplied by 10, cm -> mm).
        '''
        df = pd.read_csv(self.hparam.csv_file, skiprows=1, index_col=0, skip_blank_lines=True)  # duplicated column will be renamed automatically

        # drop nan columns
        df = df[[c for c in df.columns if 'Unnamed' not in c]]

        # drop organs with 0 points (accepts '0' as well as numeric zeros)
        pts = pd.to_numeric(df.loc['Points Number'], errors='coerce')
        df = df.drop(pts.index[pts == 0].tolist(), axis='columns')

        # keep only the skin organ(s) for deposition parsing if skin is present
        skin_names = [name for name in df.columns if 'skin' in name]
        nonskin_names = [name for name in df.columns if 'skin' not in name]
        # this var will be used in loss.py, therefore we should keep the duplicated columns;
        # set unconditionally (dropping an empty list is a no-op) so it exists without skin too
        self.csv_loss_table = df.drop(skin_names, axis='columns')
        if skin_names:
            df = df.drop(nonskin_names, axis='columns')

        df = self._drop_duplicated_columns(df)

        # set up dict of organ info
        self.organ_info = OrderedBunch(df.loc[['Grid Size', 'Points Number']].astype(float).to_dict())
        for organ_name, v in list(self.organ_info.items()):
            self.organ_info[organ_name]['Grid Size'] = v['Grid Size'] * 10.  # cm to mm
            self.organ_info[organ_name]['Points Number'] = int(v['Points Number'])
        cprint('following csv info will be used to parsing deposition matrix', 'green')
        pp.pprint(dict(self.organ_info))

        # fail early (KeyError) if a row required by the loss function is missing
        self.csv_loss_table.loc[['Grid Size', 'Points Number', 'Hard/Soft', 'Constraint Type', 'Min Dose', 'Max Dose', 'DVH Volume', 'Priority']]
        cprint('following csv info will be used in loss function', 'green')
        with pd.option_context('display.max_rows', None, 'display.max_columns', None):
            print(self.csv_loss_table.head(10))

    def get_organName_from_pointNum(self, pointsNum):
        '''Return the first organ whose 'Points Number' equals pointsNum.

        Raises:
            ValueError: if no organ has that point number.
        '''
        for organName, v in self.organ_info.items():
            if v['Points Number'] == pointsNum:
                return organName
        raise ValueError(f'Can not find organ name with pointNum={pointsNum}')

    def get_pointNum_from_organName(self, organ_name):
        '''Return the 'Points Number' of organ_name.

        Raises:
            ValueError: if organ_name is not in the csv organ table.
        '''
        if organ_name not in self.organ_info:
            raise ValueError(f'Can not find organ name {organ_name} in OrganInfo.csv')
        return self.organ_info[organ_name]['Points Number']

    def _read_rayIdx_mat(self):
        '''Parse hparam.valid_ray_file into per-beam bool and ray-index matrices.

        File format: a line containing 'F' starts beam int(line without 'F') + 1
        (beam ids start from 1); every following non-blank line is a
        whitespace-separated row of 0/1 flags (1 = valid ray).

        Return:
            dict_rayBoolMat: {beam_id: bool ndarray (H, W)}
            dict_rayIdxMat: {beam_id: int ndarray (H, W)} with consecutive ray
                indices in row-major order over beams; invalid rays hold max_ray_idx
            num_beams: id of the last beam read
        '''
        dict_rayBoolMat = collections.OrderedDict()
        beam_id = 0
        with open(self.hparam.valid_ray_file, 'r') as f:
            for line in f:
                if 'F' in line:  # new beam
                    beam_id = int(line.replace('F', '')) + 1  # NOTE: index of beam start from 1
                    dict_rayBoolMat[beam_id] = []
                elif line.strip():  # skip blank lines instead of producing empty rows
                    dict_rayBoolMat[beam_id].append(np.array(line.split(), dtype=float))
        num_beams = beam_id

        # convert (list of 1D arrays) to (2D bool matrix)
        ray_num = 0
        for beam_id, rows in list(dict_rayBoolMat.items()):
            FM = np.asarray(rows, dtype=bool)
            dict_rayBoolMat[beam_id] = FM
            ray_num += int(FM.sum())
        assert ray_num == self.max_ray_idx, f'shape not match: rayNum={ray_num}, max_ray_idx={self.max_ray_idx}'

        # number valid rays consecutively; boolean-mask assignment fills in row-major order
        dict_rayIdxMat = collections.OrderedDict()
        next_idx = 0
        for beam_id, FM in dict_rayBoolMat.items():
            idx_matrix = np.full(FM.shape, self.max_ray_idx, dtype=int)  # max_ray_idx marks non-valid rays
            n_valid = int(FM.sum())
            idx_matrix[FM] = np.arange(next_idx, next_idx + n_valid)
            next_idx += n_valid
            dict_rayIdxMat[beam_id] = idx_matrix
        return dict_rayBoolMat, dict_rayIdxMat, num_beams

    def project_to_fluenceMaps(self, fluenceVector):
        '''Convert 1D fluenceVector to 2D fluenceMap

        Arguments:
            fluenceVector: ndarray (#bixels, )
        Return:
            {beam_id: fluenceMap ndarray (H,W)}; non-valid rays get 0
        '''
        # append a trailing 0 at index #bixels (== max_ray_idx, the non-valid marker)
        tmp = np.append(fluenceVector, 0)
        dict_FluenceMap = collections.OrderedDict()
        for beam_id, ray_idx in self.dict_rayIdxMat.items():
            dict_FluenceMap[beam_id] = tmp[ray_idx]  # fancy 2D indexing
        return dict_FluenceMap

    def project_to_fluenceMaps_torch(self, fluence):
        '''Torch version of project_to_fluenceMaps.

        fluence: tensor (#bixels, )
        return: {beam_id: fluenceMap with the shape of (H,W)}
        '''
        # trailing zero (same dtype/device as fluence) for the non-valid ray marker
        tmp = torch.cat([fluence, fluence.new_zeros(1)])
        dict_FluenceMap = collections.OrderedDict()
        for beam_id, ray_idx in self.dict_rayIdxMat.items():
            dict_FluenceMap[beam_id] = tmp[ray_idx]
        return dict_FluenceMap

    def get_rays_from_fluences(self, dict_FluenceMat):
        '''Gather the valid-ray entries of each beam's fluence matrix.

        dict_FluenceMat: {beam_id: fluence matrix}, paired with dict_rayBoolMat by order
        return: valid_rays: (#valid_bixels,)
        '''
        valid_rays = [F[boolMat].flatten()
                      for (_, boolMat), (_, F) in zip(self.dict_rayBoolMat.items(), dict_FluenceMat.items())]
        return np.concatenate(valid_rays, axis=0)

    def project_to_validRays_torch(self, dict_fluences):
        '''Convert flattened fluenceMaps to a valid fluenceVector.

        Arguments:
            dict_fluences: {beam_id: fluence vector}, paired with dict_rayBoolMat by order
        Return:
            valid_rays: (#valid_bixels,)
            dict_fluenceMaps: {beam_id: detached fluence}
        '''
        dict_fluenceMaps = OrderedBunch()
        valid_rays = []
        for (beam_id, msk), (_, fluence) in zip(self.dict_rayBoolMat.items(), dict_fluences.items()):
            msk = torch.tensor(msk, dtype=torch.bool, device=fluence.device)
            valid_rays.append(fluence.view(*msk.shape)[msk].flatten())  # select valid rays and back to 1d vector
            dict_fluenceMaps[beam_id] = fluence.detach()
        valid_rays = torch.cat(valid_rays, dim=0)
        return valid_rays, dict_fluenceMaps

    def _get_organ_3D_index(self):
        '''Build per-organ 3D bool masks from the RTStruct.

        Organs are ordered by (Priority, ptv/oar, Min Dose) descending so the
        arg_max contour reader resolves overlaps in that order.

        Sets:
            self.organ_masks {organ_name: bool mask (z, x, y)}
            self.CT, self.Dicom_Reader
        '''
        ## get organ priorities from csv file (only min_dose and priority are considered)
        df = self.csv_loss_table.loc[['Min Dose', 'Priority']].astype(float)

        # add a row to identify ptv (1) / oar (0)
        names = list(df.columns)
        ptv_oar = np.array([1 if 'TV' in name else 0 for name in names]).reshape(1, -1)
        df = pd.concat([df, pd.DataFrame(ptv_oar, index=['ptv/oar'], columns=names)])
        df = self._drop_duplicated_columns(df)  # remove duplicated organ name

        # sort to identify the overlapped organs and write next to the csv to verify;
        # written into the csv's directory so the input csv can never be overwritten
        sorted_df = df.sort_values(by=['Priority', 'ptv/oar', 'Min Dose'], axis='columns', ascending=False)
        sorted_df.to_csv(os.path.join(os.path.dirname(self.hparam.csv_file), 'sorted_organs.csv'))
        cprint('following organ order will be used to parse RTStruct', 'green')
        print(sorted_df)

        ## get contour from dicom
        Dicom_Reader = Dicom_to_Imagestack(get_images_mask=True, arg_max=True)  # arg_max is important to get the right order for overlapped organs.
        Dicom_Reader.Make_Contour_From_directory(self.hparam.CT_RTStruct_dir)

        # organs missing from the RTStruct are skipped (the previous flag-and-raise
        # check never fired, so skipping is the established behavior)
        roi_names = []
        for name in sorted_df.columns:
            if name not in Dicom_Reader.all_rois:
                cprint(f'Warning: {name} not in RTStruct! we simply skip it.', 'red')
            else:
                roi_names.append(name)
        cprint(f'number of organ: {len(roi_names)}', 'green')

        # get contours
        Dicom_Reader.set_contour_names(roi_names)
        Dicom_Reader.Make_Contour_From_directory(self.hparam.CT_RTStruct_dir)

        # match MonteCarlo dose's shape (tuple() so list-valued configs compare correctly)
        mc_shape = tuple(self.hparam.MCDose_shape)
        if tuple(Dicom_Reader.mask.shape) != mc_shape:
            cprint(f'\nresize contour {Dicom_Reader.mask.shape} to match MC shape {self.hparam.MCDose_shape}', 'yellow')
            Dicom_Reader.mask = resize(Dicom_Reader.mask, mc_shape, order=0, mode='constant', cval=0, clip=False,
                                       preserve_range=True, anti_aliasing=False).astype(np.uint8)

        # match network output shape: resize depth, then center-crop H and W
        net_shape = self.hparam.net_output_shape
        if net_shape != '':
            if Dicom_Reader.mask.shape[0] != net_shape[0]:
                cprint(f'resize and crop contour {Dicom_Reader.mask.shape} to match network output shape {self.hparam.net_output_shape}', 'yellow')
                Dicom_Reader.mask = resize(Dicom_Reader.mask, (net_shape[0],) + Dicom_Reader.mask.shape[1:],
                                           order=0, mode='constant', cval=0, clip=False, preserve_range=True,
                                           anti_aliasing=False).astype(np.uint8)
            crop_top = int((Dicom_Reader.mask.shape[1] - net_shape[1] + 1) * 0.5)
            crop_left = int((Dicom_Reader.mask.shape[2] - net_shape[2] + 1) * 0.5)
            Dicom_Reader.mask = Dicom_Reader.mask[:, crop_top:crop_top + net_shape[1], crop_left:crop_left + net_shape[2]]
        cprint(f'shape of contour label volume = {Dicom_Reader.mask.shape}', 'green')
        cprint(f'max label in contour label volume = {Dicom_Reader.mask.max()}', 'green')

        # label mask -> bool mask; label i corresponds to roi_names[i-1]
        self.organ_masks = OrderedBunch()
        for i in range(1, int(Dicom_Reader.mask.max()) + 1):
            self.organ_masks[roi_names[i - 1]] = (Dicom_Reader.mask == i)

        # we may use these var out the method
        self.CT = Dicom_Reader.dicom_handle
        self.Dicom_Reader = Dicom_Reader
class MonteCarlo():
    '''Generate random MLC apertures per beam and drive the Monte Carlo unit-MU dose run.

    Arguments:
        hparam: config object; fields read here include winServer_MonteCarloDir,
            patient_ID, winServer_nb_threads and MCDose_shape.
        data: a Data-like object providing num_beams and dict_rayBoolMat.
    '''

    def __init__(self, hparam, data):
        self.hparam = hparam
        self.data = data
        self.nb_leafPairs = 51  # 51 leaf pairs
        self.x_spacing = 0.5  # cm
        self.nb_apertures = 1000  # we will generate this number random apertures
        self.nb_beams = data.num_beams

        self._get_leafBottomEdgePosition()
        self._get_leafInJawField()  # get y axis leaf position from jaw_y1 ,jaw_y2

    def get_random_apertures(self):
        '''Create (or load cached) random apertures.

        Sets self.dict_randomApertures {beam_id: uint8 ndarray (nb_apertures, H, W)};
        aperture 0 of every beam is fully open. Cached at
        <patient_ID>/dataset/dict_randomApertures.pickle.
        '''
        def get_random_shape(H, W):
            # either random blob shapes, or one random open interval per row
            if np.random.randint(0, 2):
                img = random_shapes((H, W), max_shapes=3, multichannel=False, min_size=min(H, W) // 3,
                                    allow_overlap=True, intensity_range=(1, 1))[0]
                img = np.where(img == 255, 0, img)  # background (255) -> closed
            else:
                img = np.zeros((H, W), dtype=np.uint8)
                for i in range(len(img)):  # for each row
                    l, r = np.random.randint(0, W + 1, (2, ))
                    if l == r:
                        continue
                    if l > r:
                        l, r = r, l
                    img[i, l:r] = 1
            return img

        save_path = Path(self.hparam.patient_ID).joinpath('dataset/dict_randomApertures.pickle')
        if os.path.isfile(save_path):
            self.dict_randomApertures = unpickle_object(save_path)
            return

        self.dict_randomApertures = OrderedBunch()
        for beam_id in range(1, self.nb_beams + 1):  # for each beam
            H, W = self.data.dict_rayBoolMat[beam_id].shape
            apertures = np.zeros((self.nb_apertures, H, W), np.uint8)  # default closed apertures
            if self.nb_apertures:
                apertures[0] = 1  # first aperture of each beam is all-leaf-opened
            for i in range(1, self.nb_apertures):
                apertures[i] = get_random_shape(H, W)
            self.dict_randomApertures[beam_id] = apertures
        save_path.parent.mkdir(parents=True, exist_ok=True)
        pickle_object(save_path, self.dict_randomApertures)

    def _get_leafBottomEdgePosition(self):
        '''Compute the bottom-edge y position of every leaf pair.

        the leaf coords is:
                    jaw_y2(+)
            jaw_x1(-)         jaw_x2(+)
                    jaw_y1(-)
        The 26th leaf (index 25) is centered at y=0.

        Sets:
            self.leaf_thicks, list of 51 leaf thicknesses read from FM_info.txt
            self.coords, list of 51 leaves' bottom edge positions, top to bottom
        '''
        FM_info_template = os.path.join(self.hparam.winServer_MonteCarloDir, 'templates', 'FM_info.txt')
        with open(FM_info_template, 'r') as f:
            lines = f.readlines()

        ## 0. thickness of the leaf pairs: the values following 'MLC_LeafThickness'
        thicks = []
        is_thick_line = False
        for line in lines:
            if 'MLC_LeafThickness' in line:
                is_thick_line = True
                continue
            if len(thicks) == self.nb_leafPairs:
                break
            if is_thick_line:
                thicks.append(float(line))
        self.leaf_thicks = thicks

        ## 1. bottom-edge coord of leaves (51 pairs)
        # upper half leaves: 25 bottom-edge positions
        coords = []
        edge = thicks[25] / 2.  # 26-th leaf with its center at y=0
        coords.append(edge)
        for i in range(24, 0, -1):  # +24 positions
            edge += thicks[i]
            coords.append(edge)
        coords = coords[::-1]

        # lower half leaves: 26 bottom-edge positions
        edge = -thicks[25] / 2.
        coords.append(edge)
        for i in range(26, self.nb_leafPairs):  # +25 positions
            edge -= thicks[i]
            coords.append(edge)

        # round to 2 decimals
        self.coords = [round(c, 2) for c in coords]

    def _get_leafInJawField(self):
        '''Find, per beam, which leaf pairs lie in the jaws' open field.

        Sets:
            self.dict_jawsPos {beam_id: jaw with x1, x2, y1, y2}
            self.dict_inJaw {beam_id: bool (51,)}
        '''
        self.dict_jawsPos = OrderedBunch()
        self.dict_inJaw = OrderedBunch()

        def beam_number(path):
            # numeric sort so Seg_beamID10 comes after Seg_beamID9 (a plain sort puts it after 1)
            stem = os.path.basename(path)[len('Seg_beamID'):-len('.txt')]
            return (0, int(stem)) if stem.isdigit() else (1, stem)

        seg_files = glob.glob(os.path.join(self.hparam.winServer_MonteCarloDir, 'templates', 'Seg_beamID*.txt'))
        seg_files.sort(key=beam_number)  # sort to be consistent with beam_id
        for beam_id, seg in enumerate(seg_files, start=1):
            H, W = self.data.dict_rayBoolMat[beam_id].shape
            with open(seg, 'r') as f:
                lines = f.readlines()

            ## jaw positions: the line after 'MU_CollimatorJawX1'
            is_jaw_line = False
            jaw = OrderedBunch()
            for line in lines:
                if 'MU_CollimatorJawX1' in line:
                    is_jaw_line = True
                    continue
                if is_jaw_line:
                    position = [float(p) for p in line.split(' ')[1:5]]
                    jaw.x1, jaw.x2, jaw.y1, jaw.y2 = position
                    print(f'jaw position: {jaw.x1, jaw.x2, jaw.y1, jaw.y2}')
                    break
            self.dict_jawsPos[beam_id] = jaw

            ## a leaf is in field when it overlaps (jaw_y1, jaw_y2):
            # bottom edge below jaw_y2 and upper edge above jaw_y1.
            # Leaf i's upper edge is leaf i-1's bottom edge; the top leaf uses its own thickness
            # (coords[i - 1] would wrap to the bottom-most edge for i == 0).
            in_jaw = np.empty((self.nb_leafPairs, ), dtype=bool)
            for i, bottom in enumerate(self.coords):
                upper = self.coords[i - 1] if i > 0 else bottom + self.leaf_thicks[0]
                in_jaw[i] = (jaw.y1 < bottom < jaw.y2) or (bottom < jaw.y2 and upper > jaw.y1)
            self.dict_inJaw[beam_id] = in_jaw
            assert in_jaw.sum() == H, f'H={H}, inJaw={in_jaw.sum()}'

    def _get_leafPos_for_a_row(self, row, jaw_x1, default_lr):
        '''Return the "left right\\n" leaf-position string for one aperture row.

        A closed row returns default_lr; otherwise the leaves open from the first
        to the last open bixel, at jaw_x1 + index * x_spacing.
        '''
        open_cols = np.flatnonzero(row)
        if open_cols.size == 0:  # closed row
            return default_lr
        first, last = open_cols[0], open_cols[-1]
        # TODO: block the left bixel of first 1 and right bixel of last 1 (last += 1)?
        l = jaw_x1 + first * self.x_spacing  # spacing in cm
        r = jaw_x1 + last * self.x_spacing
        return '{:.2f} {:.2f}\n'.format(l, r)

    def _get_x_axis_position(self):
        '''Convert self.dict_randomApertures into leaf x positions.

        Sets:
            self.dict_lrs {beam_id: object ndarray of strings (#aperture, 51)},
            NOTE: 51 leaf pairs in reversed order (TPS convention).
        '''
        self.dict_lrs = OrderedBunch()
        for beam_id, apts in self.dict_randomApertures.items():
            jaw_x1 = self.dict_jawsPos[beam_id].x1
            pos = jaw_x1 - self.x_spacing  # leaf closed at jaw_x1-0.5 by default
            default_lr = '{:.2f} {:.2f}\n'.format(pos, pos)
            lrs = np.full((self.nb_apertures, self.nb_leafPairs), default_lr, dtype=object)
            # aperture row r maps to the r-th leaf pair that lies inside the jaw field
            in_jaw_leaves = np.flatnonzero(self.dict_inJaw[beam_id])
            for a in range(self.nb_apertures):
                for row_idx, leaf_idx in enumerate(in_jaw_leaves):
                    lrs[a, leaf_idx] = self._get_leafPos_for_a_row(apts[a, row_idx], jaw_x1, default_lr)
                lrs[a] = lrs[a, ::-1].copy()  # NOTE: In TPS, 51 leaf pairs are in reversed order.
            self.dict_lrs[beam_id] = lrs

    def write_to_seg_txt(self):
        '''Write Seg_{beam_id}_{aperture_id}.txt files to <winServer_MonteCarloDir>/Segs.

        Each file is the beam's Seg_beamID template followed by the 51 leaf
        position lines from self.dict_lrs.
        '''
        seg_dir = os.path.join(self.hparam.winServer_MonteCarloDir, 'Segs')
        os.makedirs(seg_dir, exist_ok=True)
        for beam_id in range(1, self.nb_beams + 1):
            seg_template = os.path.join(self.hparam.winServer_MonteCarloDir, 'templates', f'Seg_beamID{beam_id}.txt')
            with open(seg_template, 'r') as f:
                lines = f.readlines()

            for aperture_id in range(self.nb_apertures):
                ap_lines = lines + list(self.dict_lrs[beam_id][aperture_id])  # 51 leaves positions
                save_path = os.path.join(seg_dir, f'Seg_{beam_id}_{aperture_id}.txt')
                with open(save_path, 'w') as f:
                    f.writelines(ap_lines)
                cprint(f'Writing Seg_{beam_id}_{aperture_id}.txt', 'green')
        cprint(f'Done. {self.nb_beams*self.nb_apertures} Seg*.txt files have been written to Dir {self.hparam.winServer_MonteCarloDir}/Segs.', 'green')

    def get_unit_MCdose(self):
        '''Write the Seg files (if not yet written) and launch gDPM on the Windows server.

        Results are stored by the server under hparam.winServer_MonteCarloDir;
        read them back with get_dose. Returns nothing.
        '''
        if not hasattr(self, 'dict_randomApertures'):
            self.get_random_apertures()
        self._get_x_axis_position()  # get x axis position from the saved random generated fluences

        cprint(f'compute unit MU Dose on winServer and save results to {self.hparam.winServer_MonteCarloDir}', 'green')
        last_seg = f'Seg_{self.nb_beams}_{self.nb_apertures - 1}.txt'  # last file written
        if not Path(self.hparam.winServer_MonteCarloDir, 'Segs', last_seg).is_file():
            self.write_to_seg_txt()
        call_FM_gDPM_on_windowsServer(self.hparam.patient_ID, self.nb_beams, self.nb_apertures,
                                      self.hparam.winServer_nb_threads)

    def get_dose(self, uid):
        '''Read one gDPM result file.

        return: mcDose (#slice, H, W), float32, axes 1 and 2 swapped relative to
        hparam.MCDose_shape
        '''
        dpm_result_file = Path(self.hparam.winServer_MonteCarloDir, 'gDPM_results', f'dpm_result_{uid}Ave.dat')
        with open(dpm_result_file, 'rb') as f:
            dose = np.fromfile(f, dtype=np.float32)
        dose = dose.reshape(*self.hparam.MCDose_shape)
        return np.swapaxes(dose, 2, 1)
def init_segments(self):
    '''Compute the loss gradient w.r.t. the fluence and solve for new segments.

    The fluence is either all zeros (fresh start) or rebuilt from the pickled
    optimized segments/MUs (hparam.optimization_continue).

    return:
        dict_gradMaps {beam_id: gradient matrix}
        new_dict_segments {beam_id: vector}
        new_dict_lrs, as returned by self.sp.solve

    Raises:
        ValueError: if continuing and the optimized_segments_MUs.pickle file is missing.
    '''
    # deposition matrix (#voxels, #bixels)
    deposition = convert_depoMatrix_to_tensor(self.data.deposition, self.hparam.device)

    # get fluence
    if self.hparam.optimization_continue:  # continue last optimization
        file_name = os.path.join(self.hparam.optimized_segments_MUs_file_path, 'optimized_segments_MUs.pickle')
        if not os.path.isfile(file_name):
            raise ValueError(f'file not exist: {file_name}')
        cprint(f'continue last optimization from {file_name}', 'yellow')
        segments_and_MUs = unpickle_object(file_name)
        dict_segments, dict_MUs = OrderedBunch(), OrderedBunch()
        for beam_id, seg_MU in segments_and_MUs.items():
            dict_segments[beam_id] = torch.tensor(seg_MU['Seg'], dtype=torch.float32, device=self.hparam.device)
            dict_MUs[beam_id] = torch.tensor(seg_MU['MU'], dtype=torch.float32, device=self.hparam.device, requires_grad=True)
        fluence, _ = computer_fluence(self.data, dict_segments, dict_MUs)
        self.dict_segments = OrderedBunch()
        for beam_id, seg in dict_segments.items():
            self.dict_segments[beam_id] = seg.cpu().numpy()
    else:
        fluence = torch.zeros((deposition.size(1), ), dtype=torch.float32, device=self.hparam.device, requires_grad=True)  # (#bixels,)

    # a fluence computed from the MUs is a non-leaf tensor whose .grad is not
    # populated by backward(); ask autograd to keep it
    if not fluence.is_leaf:
        fluence.retain_grad()

    # compute fluence gradient
    doses = cal_dose(deposition, fluence)  # cal dose (#voxels, )
    dict_organ_doses = split_doses(doses, self.data.organName_ptsNum)  # split organ_doses to obtain individual organ doses
    loss, breaking_points_nums = self.loss.loss_func(dict_organ_doses)
    print('breaking points #: ', end='')
    for organ_name, breaking_points_num in breaking_points_nums.items():
        print(f'{organ_name}: {breaking_points_num}   ', end='')
    print(f'loss={to_np(loss)}\n\n')
    loss.backward(retain_graph=False)  # backward to get grad
    grads = fluence.grad.detach().cpu().numpy()  # (#bixels,)

    # project 1D grad vector (#valid_bixels,) to 2D fluence maps {beam_id: matrix}
    dict_gradMaps = self.data.project_to_fluenceMaps(grads)  # {beam_id: matrix}
    new_dict_segments, new_dict_lrs = self.sp.solve(dict_gradMaps)  # new_dict_segments {beam_id: vector}

    del fluence, doses, deposition
    return dict_gradMaps, new_dict_segments, new_dict_lrs