def train(self):
    """Train ``self.model`` with the settings in ``self.config``.

    Builds the loss and the training data loader described by the config,
    splits the ``*_GT.ome.tif`` images in ``config['loader']['datafolder']``
    into training/validation sets according to
    ``config['validation']['leaveout']`` and then runs
    ``config['epochs'] + 1`` epochs of Adam optimisation.  Every
    ``validate_every_n_epoch`` epochs the model is run on the held-out images
    and ``compute_iou`` scores (one per output head, with the cost map passed
    in) are averaged and printed; every ``save_every_n_epoch`` epochs a
    checkpoint is written with ``save_checkpoint``.

    Raises:
        SystemExit: if the config asks for an unsupported loss or loader, or
            disables validation (those paths are not implemented yet).
        ValueError: if ``config['validation']['leaveout']`` is empty.
        RuntimeError: if an epoch yields no training batches.
    """
    from types import SimpleNamespace

    ### load settings ###
    config = self.config  # TODO, fix this
    model = self.model

    # define loss
    # TODO, add more loss
    loss_config = config['loss']
    if loss_config['name'] != 'Aux':
        # SystemExit with a message exits with a non-zero status (a bare
        # quit() reported success to the shell).
        raise SystemExit('do not support other loss yet')
    criterion = MultiAuxillaryElementNLLLoss(
        3, loss_config['loss_weight'], config['nclass'])

    # dataloader
    validation_config = config['validation']
    loader_config = config['loader']
    if validation_config['metric'] is None:
        # TODO, update here
        raise SystemExit('need validation')

    print('prepare the data ... ...')
    filenames = sorted(glob(loader_config['datafolder'] + '/*_GT.ome.tif'))
    total_num = len(filenames)
    leave_out = validation_config['leaveout']
    if not leave_out:
        raise ValueError("config['validation']['leaveout'] must not be empty")
    if len(leave_out) == 1 and 0 < leave_out[0] < 1:
        # a single value in (0, 1): hold out that fraction of the images,
        # chosen at random
        num_train = int(np.floor((1 - leave_out[0]) * total_num))
        shuffled_idx = np.arange(total_num)
        random.shuffle(shuffled_idx)
        train_idx = shuffled_idx[:num_train]
        valid_idx = shuffled_idx[num_train:]
    else:
        # otherwise: the values are indices of the images to hold out
        valid_idx = list(map(int, leave_out))
        train_idx = list(set(range(total_num)) - set(valid_idx))

    # drop the trailing '_GT.ome.tif' (11 characters) to get the base name
    valid_filenames = [filenames[idx][:-11] for idx in valid_idx]
    train_filenames = [filenames[idx][:-11] for idx in train_idx]

    args_inference = SimpleNamespace(
        size_in=config['size_in'],
        size_out=config['size_out'],
        OutputCh=validation_config['OutputCh'],
        nclass=config['nclass'],
    )

    if loader_config['name'] == 'default':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader
    elif loader_config['name'] == 'focus':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader
    else:
        raise SystemExit('other loader not support yet')

    def _build_train_set_loader():
        # Rebuilding the loader is how the training data gets refreshed
        # ("shuffled") between epochs.
        return DataLoader(
            train_loader(train_filenames, loader_config['PatchPerBuffer'],
                         config['size_in'], config['size_out']),
            num_workers=loader_config['NumWorkers'],
            batch_size=loader_config['batch_size'],
            shuffle=True)

    train_set_loader = _build_train_set_loader()

    num_iterations = 0
    num_epoch = 0  # TODO: load num_epoch from checkpoint
    start_epoch = num_epoch

    # The learning rate is constant, so the optimizer is built once; creating
    # a new Adam every epoch would discard its running moment estimates.
    optimizer = optim.Adam(model.parameters(),
                           lr=config['learning_rate'],
                           weight_decay=config['weight_decay'])

    for _ in range(start_epoch, config['epochs'] + 1):
        # sets the model in training mode
        model.train()

        # check if re-load on training data is needed
        if num_epoch > 0 and num_epoch % loader_config['epoch_shuffle'] == 0:
            print('shuffling data')
            train_set_loader = None  # drop the old loader before building a new one
            train_set_loader = _build_train_set_loader()

        # Training starts ...
        epoch_loss = []
        for current_batch in tqdm(train_set_loader):
            inputs = current_batch[0].cuda()
            targets = current_batch[1]
            outputs = model(inputs)

            if len(targets) > 1:
                # several targets (presumably one per output head)
                for zidx, target in enumerate(targets):
                    targets[zidx] = target.cuda()
            else:
                targets = targets[0].cuda()

            optimizer.zero_grad()
            if len(current_batch) == 3:  # input + target + cmap
                cmap = current_batch[2].cuda()
                loss = criterion(outputs, targets, cmap)
            else:  # input + target
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())
            num_iterations += 1

        if not epoch_loss:
            raise RuntimeError(
                f'epoch {num_epoch}: the training data loader produced no batches')
        average_training_loss = sum(epoch_loss) / len(epoch_loss)

        # validation
        if num_epoch % validation_config['validate_every_n_epoch'] == 0:
            output_ch = validation_config['OutputCh']
            # accumulates compute_iou scores (higher is better), one per output
            validation_loss = np.zeros((len(output_ch) // 2, ))
            model.eval()

            for fn in valid_filenames:
                # target; expand_dims adds a leading size-1 axis, so
                # label.shape[0] is always 1 below.
                # NOTE(review): the per-output label branch is therefore never
                # taken, and a multi-channel GT is probably not handled —
                # confirm the expected GT layout.
                label = np.expand_dims(
                    np.squeeze(imread(fn + '_GT.ome.tif')), axis=0)

                # input image
                input_img = np.squeeze(imread(fn + '.ome.tif'))
                if input_img.ndim == 3:
                    # add channel dimension
                    input_img = np.expand_dims(input_img, axis=0)
                elif input_img.ndim == 4 and input_img.shape[0] > input_img.shape[1]:
                    # assume number of channel < number of Z, make sure channel dim comes first
                    input_img = np.transpose(input_img, (1, 0, 2, 3))

                # cmap tensor
                costmap = np.squeeze(imread(fn + '_CM.ome.tif'))

                # output
                outputs = model_inference(model, input_img,
                                          model.final_activation,
                                          args_inference)
                assert len(output_ch) // 2 == len(outputs)

                for vi, output in enumerate(outputs):
                    # a single label volume is shared by all outputs
                    label_idx = 0 if label.shape[0] == 1 else vi
                    validation_loss[vi] += compute_iou(
                        output[0, :, :, :] > 0.5,
                        label[label_idx, :, :, :] == output_ch[2 * vi + 1],
                        costmap)

            average_validation_loss = validation_loss / len(valid_filenames)
            print(
                f'Epoch: {num_epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}'
            )
        else:
            print(f'Epoch: {num_epoch}, Training Loss: {average_training_loss}')

        if num_epoch % config['save_every_n_epoch'] == 0:
            save_checkpoint(
                {
                    'epoch': num_epoch,
                    'num_iterations': num_iterations,
                    'model_state_dict': model.state_dict(),
                    # 'best_val_score': self.best_val_score,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'device': str(self.device),
                },
                checkpoint_dir=config['checkpoint_dir'],
                logger=self.logger)
        num_epoch += 1
def train(args, model):
    """Train *model* with the command-line settings in *args*.

    Creates a CSV log at ``args.LoggerName`` and splits the ``*_GT.ome.tif``
    images in ``args.DataPath`` into training/validation sets according to
    ``args.LeaveOut``. It then runs ``args.NumEpochs + 1`` epochs. Each epoch:

    * trains with Adam, with the learning rate taken from ``args.lr``;
    * averages ``compute_iou`` scores over the validation images;
    * appends ``epoch,training_loss,validation_score`` to the log;
    * saves the weights to ``args.ModelDir`` every ``args.SaveEveryKEpoch``
      epochs.

    Raises:
        SystemExit: if the log file already exists and ``args.TestMode`` is off.
        ValueError: for an unsupported loss/model or augmentation, an empty
            ``args.LeaveOut``, or a malformed ``args.lr``.
        RuntimeError: if an epoch yields no training batches.
    """
    model.train()

    # check logger
    if not args.TestMode and os.path.isfile(args.LoggerName):
        # a message makes SystemExit exit with a non-zero status
        raise SystemExit('logger file exists')
    with open(args.LoggerName, 'a') as text_file:
        print('Epoch,Training_Loss,Validation_Loss', file=text_file)

    # load the correct loss function
    multi_task_models = ('unet_2task', 'unet_xy_multi_task')
    deep_supervision_models = ('unet_ds', 'unet_xy', 'unet_deeper_xy',
                               'unet_xy_d6', 'unet_xy_p3', 'unet_xy_p2')
    if args.Loss == 'NLL_CM' and args.model in multi_task_models:
        from aicsmlsegment.custom_loss import MultiTaskElementNLLLoss
        criterion = MultiTaskElementNLLLoss(args.LossWeight, args.nclass)
        print('use 2 task elementwise NLL loss')
    elif args.Loss == 'NLL_CM' and args.model in deep_supervision_models:
        from aicsmlsegment.custom_loss import MultiAuxillaryElementNLLLoss
        criterion = MultiAuxillaryElementNLLLoss(3, args.LossWeight, args.nclass)
        print('use unet with deep supervision loss')
    else:
        raise ValueError(
            f'unsupported loss/model combination: {args.Loss!r} / {args.model!r}')

    # prepare the training/validation filenames
    print('prepare the data ... ...')
    filenames = sorted(glob.glob(args.DataPath + '/*_GT.ome.tif'))
    total_num = len(filenames)
    if not args.LeaveOut:
        raise ValueError('args.LeaveOut must not be empty')
    if len(args.LeaveOut) == 1 and 0 < args.LeaveOut[0] < 1:
        # a single value in (0, 1): hold out that fraction, chosen at random
        num_train = int(np.floor((1 - args.LeaveOut[0]) * total_num))
        shuffled_idx = np.arange(total_num)
        random.shuffle(shuffled_idx)
        train_idx = shuffled_idx[:num_train]
        valid_idx = shuffled_idx[num_train:]
    else:
        # otherwise: the values are indices of the images to hold out
        valid_idx = list(map(int, args.LeaveOut))
        train_idx = list(set(range(total_num)) - set(valid_idx))

    # drop the trailing '_GT.ome.tif' (11 characters) to get the base name
    valid_filenames = [filenames[idx][:-11] for idx in valid_idx]
    train_filenames = [filenames[idx][:-11] for idx in train_idx]

    if args.Augmentation == 'NOAUG_M':
        from aicsmlsegment.DataLoader3D.Universal_Loader import NOAUG_M as train_loader
        print('use no augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M as train_loader
        print('use flip + rotation augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M0':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader
        print('use flip + rotation augmentation, with cost map')
    elif args.Augmentation == 'RR_FH_M0C':
        from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader
        print(
            'use flip + rotation augmentation, with cost map, and also count valid pixels'
        )
    else:
        raise ValueError(f'unsupported augmentation: {args.Augmentation!r}')

    # softmax for validation
    softmax = nn.Softmax(dim=1)
    softmax.cuda()

    def _epoch_lr(epoch):
        # args.lr is [lr] or [stage_1, lr_1, ..., stage_k, lr_k, lr_final]:
        # the first stage with epoch < stage_i gives lr_i, else lr_final.
        lr = args.lr
        if len(lr) == 1:  # single value
            return lr[0]
        if len(lr) == 0 or len(lr) % 2 == 0:
            raise ValueError(f'optimizer setup fails: invalid lr schedule {lr!r}')
        for stage_end, stage_lr in zip(lr[0:-1:2], lr[1:-1:2]):
            if epoch < stage_end:
                return stage_lr
        return lr[-1]

    for epoch in range(args.NumEpochs + 1):

        if epoch % args.EpochPerBuffer == 0:
            print('shuffling training data ... ...')
            random.shuffle(train_filenames)
            train_set_loader = DataLoader(
                train_loader(train_filenames, args.PatchPerBuffer,
                             args.size_in, args.size_out),
                num_workers=args.NumWorkers,
                batch_size=args.BatchSize,
                shuffle=True)
            print('training data is ready')

        # specific optimizer for this epoch
        # NOTE(review): a fresh Adam each epoch also resets its moment
        # estimates; only the learning rate needs to change per stage.
        optimizer = optim.Adam(model.parameters(),
                               lr=_epoch_lr(epoch),
                               weight_decay=args.WeightDecay)

        # Training starts ...
        epoch_loss = []
        model.train()
        for current_batch in tqdm(train_set_loader):
            inputs = current_batch[0].cuda()
            targets = current_batch[1]
            outputs = model(inputs)

            if len(targets) > 1:
                for zidx, target in enumerate(targets):
                    targets[zidx] = target.cuda()
            else:
                targets = targets[0].cuda()

            optimizer.zero_grad()
            if len(current_batch) == 3:  # input + target + cmap
                cmap = current_batch[2].cuda()
                loss = criterion(outputs, targets, cmap)
            else:  # input + target
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            epoch_loss.append(loss.item())

        if not epoch_loss:
            raise RuntimeError(
                f'epoch {epoch}: the training data loader produced no batches')

        # Validation starts ...
        validation_loss = np.zeros((len(args.OutputCh) // 2, ))
        model.eval()
        for fn in valid_filenames:
            # target (CZYX, 4-D after squeeze)
            label = np.squeeze(AICSImage(fn + '_GT.ome.tif').data, axis=0)
            # When the tif has only 1 channel, the loaded array may have falsely
            # swapped dimensions (ZCYX); we want CZYX. (This may also happen with
            # a different OS or different package versions.)
            # ASSUMPTION: we have more z slices than the number of channels
            if label.shape[1] < label.shape[0]:
                label = np.transpose(label, (1, 0, 2, 3))

            # input image (CZYX)  TODO: check size
            input_img = np.squeeze(AICSImage(fn + '.ome.tif').data, axis=0)
            if input_img.shape[1] < input_img.shape[0]:
                input_img = np.transpose(input_img, (1, 0, 2, 3))

            # cost map, reduced to ZYX
            costmap = np.squeeze(AICSImage(fn + '_CM.ome.tif').data, axis=0)
            if costmap.shape[0] == 1:
                costmap = np.squeeze(costmap, axis=0)
            elif costmap.shape[1] == 1:
                costmap = np.squeeze(costmap, axis=1)

            # output
            outputs = model_inference(model, input_img, softmax, args)
            assert len(args.OutputCh) // 2 == len(outputs)

            for vi, output in enumerate(outputs):
                # a single-channel label is shared by all outputs
                label_idx = 0 if label.shape[0] == 1 else vi
                validation_loss[vi] += compute_iou(
                    output[0, :, :, :] > 0.5,
                    label[label_idx, :, :, :] == args.OutputCh[2 * vi + 1],
                    costmap)

        # print loss
        average_training_loss = sum(epoch_loss) / len(epoch_loss)
        average_validation_loss = validation_loss / len(valid_filenames)
        print(
            f'Epoch: {epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}'
        )
        # NOTE(review): the validation column is a numpy array, so it is
        # written as e.g. "[0.8]" rather than a plain number.
        with open(args.LoggerName, 'a') as text_file:
            print(f'{epoch},{average_training_loss},{average_validation_loss}',
                  file=text_file)

        # save the model
        if args.SaveEveryKEpoch > 0 and epoch % args.SaveEveryKEpoch == 0:
            filename = f'{args.model}-{epoch:03}-{args.model_tag}.pth'
            torch.save(model.state_dict(), os.path.join(args.ModelDir, filename))
            print(f'save at epoch: {epoch}')
def main():
    """Command-line entry point: segment images with a trained model.

    Loads the config given by ``--config`` (via ``load_config``), builds the
    model and loads its weights from ``config['model_path']``.
    ``config['mode']['name']`` selects the input:

    * ``'file'``: the single image ``InputFile``. With ``timelapse`` set,
      each time point is processed separately.
    * ``'folder'``: every ``*<DataType>`` file in ``InputDir``.

    The results are written to ``config['OutputDir']``.

    Raises:
        SystemExit: if ``config['mode']['name']`` is not supported.
    """
    from types import SimpleNamespace

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    args = parser.parse_args()

    config = load_config(args.config)

    # declare the model
    model = build_model(config)

    # load the trained model instance
    model_path = config['model_path']
    print(f'Loading model from {model_path}...')
    load_checkpoint(model_path, model)

    # parameters for preparing the input image
    args_norm = SimpleNamespace(Normalization=config['Normalization'])

    # parameters for running the model inference
    args_inference = SimpleNamespace(
        size_in=config['size_in'],
        size_out=config['size_out'],
        OutputCh=config['OutputCh'],
        nclass=config['nclass'],
    )

    resize_ratio = config['ResizeRatio']
    inverse_ratio = None
    if len(resize_ratio) > 0:
        inverse_ratio = (1.0, 1 / resize_ratio[0], 1 / resize_ratio[1],
                         1 / resize_ratio[2])
    output_ch = config['OutputCh']
    threshold = config['Threshold']

    def _rescale_01(arr):
        # min-max rescale to [0, 1]; a constant array maps to zeros instead
        # of dividing by zero
        lo, hi = arr.min(), arr.max()
        if hi > lo:
            return (arr - lo) / (hi - lo)
        return np.zeros_like(arr)

    def _channels_first(img):
        # assumes fewer channels than Z slices: a ZCYX array is swapped to CZYX
        if img.shape[1] < img.shape[0]:
            img = np.transpose(img, (1, 0, 2, 3))
        return img

    def _prepare(img):
        # normalise the 4-D channel-first image, optionally resize it and
        # rescale each resized channel to [0, 1]
        img = input_normalization(img, args_norm)
        if len(resize_ratio) > 0:
            img = resize(img, (1, resize_ratio[0], resize_ratio[1],
                               resize_ratio[2]),
                         method='cubic')
            for ch_idx in range(img.shape[0]):
                img[ch_idx, :, :, :] = _rescale_01(img[ch_idx, :, :, :])
        return img

    def _save(out, filename):
        writer = omeTifWriter.OmeTifWriter(config['OutputDir'] + os.sep + filename)
        writer.save(out)

    def _file_mode_output(out):
        # rescale, undo the input resize, and binarise to 0/255 when
        # Threshold > 0 (otherwise keep float32 values)
        out = _rescale_01(out)
        if inverse_ratio is not None:
            out = resize(out, inverse_ratio, method='cubic')
        out = out.astype(np.float32)
        if threshold > 0:
            out = out > threshold
            out = out.astype(np.uint8)
            out[out > 0] = 255
        return out

    def _write_file_mode(output_img, prefix):
        if len(output_ch) == 2:
            _save(_file_mode_output(output_img[0]),
                  prefix + '_struct_segmentation.tiff')
        else:
            for ch_idx in range(len(output_ch) // 2):
                _save(_file_mode_output(output_img[ch_idx]),
                      prefix + '_seg_' + str(output_ch[2 * ch_idx]) + '.tiff')

    # run
    inf_config = config['mode']
    if inf_config['name'] == 'file':
        fn = inf_config['InputFile']
        stem = pathlib.PurePosixPath(fn).stem
        img0 = AICSImage(fn).data

        if inf_config['timelapse']:
            assert img0.shape[0] > 1
            for tt in range(img0.shape[0]):
                # Assume: dimensions = TCZYX
                img = _prepare(img0[tt, config['InputCh'], :, :, :].astype(float))
                # apply the model
                output_img = model_inference(model, img, model.final_activation,
                                             args_inference)
                _write_file_mode(output_img, stem + '_T_' + f'{tt:03}')
        else:
            img = img0[0, :, :, :, :].astype(float)
            print(f'processing one image of size {img.shape}')
            img = _prepare(_channels_first(img)[config['InputCh'], :, :, :])
            # apply the model
            output_img = model_inference(model, img, model.final_activation,
                                         args_inference)
            _write_file_mode(output_img, stem)
        print(f'Image {fn} has been segmented')

    elif inf_config['name'] == 'folder':
        from glob import glob
        filenames = sorted(glob(inf_config['InputDir'] + '/*' + inf_config['DataType']))
        for fn in filenames:
            stem = pathlib.PurePosixPath(fn).stem
            # load data
            img = AICSImage(fn).data[0, :, :, :, :].astype(float)
            img = _prepare(_channels_first(img)[config['InputCh'], :, :, :])

            # apply the model
            output_img = model_inference(model, img, model.final_activation,
                                         args_inference)

            # NOTE(review): unlike 'file' mode, the thresholded result and the
            # multi-output results are not resized back by ResizeRatio.
            if len(output_ch) == 2:
                if threshold < 0:
                    out = _rescale_01(output_img[0])
                    if inverse_ratio is not None:
                        out = resize(out, inverse_ratio, method='cubic')
                    out = _rescale_01(out.astype(np.float32))
                else:
                    out = remove_small_objects(output_img[0] > threshold,
                                               min_size=2, connectivity=1)
                    out = out.astype(np.uint8)
                    out[out > 0] = 255
                _save(out, stem + '_struct_segmentation.tiff')
            else:
                for ch_idx in range(len(output_ch) // 2):
                    if threshold < 0:
                        out = _rescale_01(output_img[ch_idx]).astype(np.float32)
                    else:
                        out = (output_img[ch_idx] > threshold).astype(np.uint8)
                        out[out > 0] = 255
                    _save(out, stem + '_seg_' + str(output_ch[2 * ch_idx]) + '.ome.tif')
            print(f'Image {fn} has been segmented')

    else:
        raise SystemExit(f"unsupported mode: {inf_config['name']!r}")
def evaluate(args, model):
    """Run a trained *model* on images and write the segmentations to disk.

    ``args.mode`` selects the input:

    * ``'eval'``: every ``*<args.DataType>`` file in ``args.InputDir``. Each is
      loaded with ``load_single_image`` and segmented with ``apply_on_image``.
    * ``'eval_file'``: the single file ``args.InputFile``. With
      ``args.timelapse``, each index of the first axis is processed
      separately. The input is resized by ``args.ResizeRatio`` (if given),
      and the outputs are resized back.

    One ``.ome.tif`` is written per output, labelled by the even-indexed
    entries of ``args.OutputCh``. The file path is ``args.OutputDir``
    concatenated as-is with the input stem, so ``OutputDir`` should end with
    a path separator. With ``args.Threshold < 0`` the raw output is saved as
    float. Otherwise it is binarised at the threshold and saved as uint8
    with values 0/255.

    Raises:
        ValueError: if ``args.nchannel`` differs from ``len(args.InputCh)``,
            or if ``args.mode`` is not recognised.
    """
    model.eval()
    softmax = nn.Softmax(dim=1)
    softmax.cuda()

    # check validity of parameters
    if args.nchannel != len(args.InputCh):
        raise ValueError(
            'number of input channel does not match input channel indices')

    def _rescale_01(arr):
        # min-max rescale to [0, 1]; a constant array maps to zeros instead
        # of dividing by zero
        lo, hi = arr.min(), arr.max()
        if hi > lo:
            return (arr - lo) / (hi - lo)
        return np.zeros_like(arr)

    def _prepare(img):
        # normalise the 4-D channel-first image, optionally resize it and
        # rescale each resized channel to [0, 1]
        img = input_normalization(img, args)
        if len(args.ResizeRatio) > 0:
            img = resize(img, (1, args.ResizeRatio[0], args.ResizeRatio[1],
                               args.ResizeRatio[2]),
                         method='cubic')
            for ch_idx in range(img.shape[0]):
                img[ch_idx, :, :, :] = _rescale_01(img[ch_idx, :, :, :])
        return img

    def _save_outputs(output_img, prefix, inverse_ratio=None):
        # write one file per output; resize back first when inverse_ratio is set
        for ch_idx in range(len(args.OutputCh) // 2):
            writer = omeTifWriter.OmeTifWriter(
                prefix + '_seg_' + str(args.OutputCh[2 * ch_idx]) + '.ome.tif')
            if args.Threshold < 0:
                out = output_img[ch_idx].astype(float)
                if inverse_ratio is not None:
                    out = resize(out, inverse_ratio, method='cubic')
            else:
                out = output_img[ch_idx] > args.Threshold
                if inverse_ratio is not None:
                    out = resize(out, inverse_ratio, method='nearest')
                out = out.astype(np.uint8)
                out[out > 0] = 255
            writer.save(out)

    if args.mode == 'eval':
        filenames = sorted(glob.glob(args.InputDir + '/*' + args.DataType))
        for fn in filenames:
            print(fn)
            # load data
            struct_img = load_single_image(args, fn, time_flag=False)
            print(struct_img.shape)

            # apply the model
            output_img = apply_on_image(model, struct_img, softmax, args)
            _save_outputs(output_img,
                          args.OutputDir + pathlib.PurePosixPath(fn).stem)
            print(f'Image {fn} has been segmented')

    elif args.mode == 'eval_file':
        fn = args.InputFile
        print(fn)
        data_reader = AICSImage(fn)
        img0 = data_reader.data
        prefix = args.OutputDir + pathlib.PurePosixPath(fn).stem

        # Undo the input resize on the outputs. Previously the timelapse path
        # indexed ResizeRatio even when it was empty, and the single-image
        # path never resized back.
        inverse_ratio = None
        if len(args.ResizeRatio) > 0:
            inverse_ratio = (1.0, 1 / args.ResizeRatio[0],
                             1 / args.ResizeRatio[1], 1 / args.ResizeRatio[2])

        if args.timelapse:
            assert data_reader.shape[0] > 1
            for tt in range(data_reader.shape[0]):
                # Assume: TCZYX
                img = _prepare(img0[tt, args.InputCh, :, :, :].astype(float))
                # apply the model
                output_img = model_inference(model, img, softmax, args)
                _save_outputs(output_img, prefix + '_T_' + f'{tt:03}',
                              inverse_ratio)
        else:
            img = img0[0, :, :, :].astype(float)
            # assumes fewer channels than Z slices: swap ZCYX to CZYX
            if img.shape[1] < img.shape[0]:
                img = np.transpose(img, (1, 0, 2, 3))
            img = _prepare(img[args.InputCh, :, :, :])
            # apply the model
            output_img = model_inference(model, img, softmax, args)
            _save_outputs(output_img, prefix, inverse_ratio)

        print(f'Image {fn} has been segmented')

    else:
        raise ValueError(f'unsupported mode: {args.mode!r}')