def __init__(self, config):
    """Build the training setup described by ``config``.

    Args:
        config: mapping with ``'data'``, ``'model'`` and ``'solver'`` sections.
            Reads ``solver['max_steps']``, ``solver['save_steps']``,
            ``solver['log_steps']``, ``solver['resume_steps']``,
            ``solver['exp_root']``, ``solver['loss']``,
            ``data['kwargs']['batch_size']`` and ``model['input_shape']``
            (plus whatever ``get_dataset``/``get_model``/``get_optimizer``
            consume).

    Side effects:
        Creates ``<exp_root>``, ``<exp_root>/models`` and ``<exp_root>/logs``
        when missing, constructs a ``Logger`` pointed at
        ``<exp_root>/logs/<timestamp>.log``, builds the model and sends its
        summary to that logger.
    """
    self.config_data = config['data']
    self.config_model = config['model']
    self.config_solver = config['solver']

    self.max_steps = self.config_solver['max_steps']
    self.save_steps = self.config_solver['save_steps']
    self.log_steps = self.config_solver['log_steps']
    self.resume_steps = self.config_solver['resume_steps']

    self.dataset = get_dataset(self.config_data)

    # Experiment layout: <exp_root>/models and <exp_root>/logs.
    # exist_ok=True replaces the exists-then-create check, which can race
    # with another process creating the same directory.
    self.exp_root = self.config_solver['exp_root']
    os.makedirs(self.exp_root, exist_ok=True)
    self.ckpt_root = os.path.join(self.exp_root, 'models')
    os.makedirs(self.ckpt_root, exist_ok=True)
    self.log_root = os.path.join(self.exp_root, 'logs')
    os.makedirs(self.log_root, exist_ok=True)

    self.model = get_model(self.config_model)
    self.loss = get_loss(self.config_solver['loss'])
    self.optimizer = get_optimizer(self.config_solver)

    # Metric accumulators.
    self.pr_meter = BinaryPRMeter()
    self.dice_meter = DiceMeter()
    self.loss_meter = tf.metrics.Mean()

    log_time = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    self.logger = Logger(os.path.join(self.log_root, f'{log_time}.log'))

    # Build with a concrete (batch_size, *input_shape) so summary() can run.
    self.input_shape = (self.config_data['kwargs']['batch_size'],
                        *self.config_model['input_shape'])
    self.model.build(self.input_shape)
    self.model.summary(print_fn=self.logger.info)
class SolverTest_2d(object):
    """Evaluate a slice-wise (2D) model on whole test volumes.

    Each test volume is cut along axis 3 into chunks of ``batch_size`` slices.
    Each chunk is run through the model, and the per-slice logits are stacked
    back into a volume before metrics are computed.
    """

    def __init__(self, config):
        """Set up data, model, output directories and logging for testing.

        Args:
            config: mapping with ``'data'``, ``'model'`` and ``'solver'``
                sections. Reads ``solver['test_steps']``,
                ``solver['export_results']``, ``solver['exp_root']``,
                ``data['kwargs']['batch_size']`` and ``model['input_shape']``
                (plus whatever ``get_dataset``/``get_model``/``get_optimizer``
                consume).
        """
        self.config_data = config['data']
        self.config_model = config['model']
        self.config_solver = config['solver']

        self.test_steps = self.config_solver['test_steps']
        self.export = self.config_solver['export_results']

        self.dataset = get_dataset(self.config_data)

        # <exp_root>/{models,logs,results}; exist_ok=True avoids the
        # exists-then-create race.
        self.exp_root = self.config_solver['exp_root']
        os.makedirs(self.exp_root, exist_ok=True)
        self.ckpt_root = os.path.join(self.exp_root, 'models')
        os.makedirs(self.ckpt_root, exist_ok=True)
        self.log_root = os.path.join(self.exp_root, 'logs')
        os.makedirs(self.log_root, exist_ok=True)
        self.result_root = os.path.join(self.exp_root, 'results')
        os.makedirs(self.result_root, exist_ok=True)

        self.model = get_model(self.config_model)
        # Only used here as part of the tf.train.Checkpoint in load_model.
        self.optimizer = get_optimizer(self.config_solver)

        self.pr_meter = BinaryPRMeter()
        self.dice_meter = DiceMeter()

        log_time = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.logger = Logger(os.path.join(self.log_root, f'{log_time}.log'))

        # (batch_size, *input_shape); batch_size is also the slice chunk size
        # used by chop_forward.
        self.input_shape = (self.config_data['kwargs']['batch_size'],
                            *self.config_model['input_shape'])
        self.model.build(self.input_shape)
        self.model.summary(print_fn=self.logger.info)

    def chop_forward(self, input_tensor, gt=None):
        """Run the model over every slice along axis 3 of ``input_tensor``.

        Args:
            input_tensor: rank-5 tensor; only batch element 0 is used and
                axis 3 is treated as the slice axis. Presumably
                ``(1, H, W, Z, C)`` -- TODO confirm against the loader.
            gt: unused; accepted so callers passing labels keep working.

        Returns:
            Logits of shape ``(1, H', W', Z, K)``, where ``(H', W', K)`` are
            the model's per-slice output dims and ``Z`` is the slice count.
        """
        # Move the slice axis to the front so slices form the model batch.
        slices = tf.transpose(input_tensor[0, ...], (2, 0, 1, 3))
        batch_size = self.input_shape[0]
        z_layers = input_tensor.shape[3]

        logits = []
        # Plain integer chunk bounds: range() cannot take a float tensor
        # (e.g. the result of tf.math.ceil) as its bound in eager mode.
        for begin in range(0, z_layers, batch_size):
            end = min(begin + batch_size, z_layers)
            chunk = slices[begin:end, ...]
            # Each chunk is standardised with its own mean/std.
            chunk = (chunk - tf.reduce_mean(chunk)) / tf.math.reduce_std(chunk)
            logits.append(self.model(chunk, training=False))

        logit_tensor = tf.concat(logits, axis=0)                  # (Z, H', W', K)
        logit_tensor = tf.transpose(logit_tensor, (1, 2, 0, 3))   # (H', W', Z, K)
        return tf.expand_dims(logit_tensor, 0)

    def load_model(self, step):
        """Restore model/optimizer state from ``<exp_root>/models/step_<step>``.

        ``expect_partial()`` suppresses warnings about checkpointed values that
        are not consumed by this restore.
        """
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.restore(ckpt_path).expect_partial()

    def run(self):
        """Restore checkpoint ``test_steps`` and evaluate every test volume.

        Logs per-volume precision/recall/Dice, then the meters' aggregate over
        all volumes. When ``export_results`` is set, also writes image,
        prediction and label ``.nii`` files to ``<exp_root>/results``.
        """
        self.logger.info('Testing started.')
        test_loader = self.dataset.get_loader(training=False)
        step = self.test_steps
        self.load_model(step)
        self.logger.info(f'[Step {step:d}] Model Loaded.')

        for test_idx, (image_tensor, label_tensor, affine_tensor) in enumerate(test_loader):
            logit_tensor = self.chop_forward(image_tensor, label_tensor)
            # Channel 0 of the softmax is taken as the positive class.
            pred_tensor = tf.nn.softmax(logit_tensor)[..., 0:1]
            print(
                f'positive predictions:{tf.math.count_nonzero(pred_tensor > 0.5).numpy()}.'
            )

            self.pr_meter.update_state(label_tensor, pred_tensor)
            self.dice_meter.update_state(label_tensor, pred_tensor)
            test_p, test_r = precision_recall(label_tensor, pred_tensor)
            test_dice = dice_coef(label_tensor, pred_tensor)

            if self.export:
                save_image_nii(
                    image_tensor, affine_tensor,
                    os.path.join(self.result_root, f'{step}_{test_idx + 1}_img.nii'))
                save_pred_nii(
                    pred_tensor, affine_tensor,
                    os.path.join(self.result_root, f'{step}_{test_idx + 1}_pred.nii'))
                save_label_nii(
                    label_tensor, affine_tensor,
                    os.path.join(self.result_root, f'{step}_{test_idx + 1}_gt.nii'))

            self.logger.info(
                f'[Test {test_idx + 1}] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
            )

        test_p, test_r = self.pr_meter.result()
        test_dice = self.dice_meter.result()
        self.logger.info(
            f'[Total Average] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
        )
class SolverTrain(object):
    """Train a segmentation model.

    Trains step by step, with periodic metric logging and checkpointing under
    ``<exp_root>``.
    """

    def __init__(self, config):
        """Build the training setup described by ``config``.

        Args:
            config: mapping with ``'data'``, ``'model'`` and ``'solver'``
                sections. Reads ``solver['max_steps']``,
                ``solver['save_steps']``, ``solver['log_steps']``,
                ``solver['resume_steps']``, ``solver['exp_root']``,
                ``solver['loss']``, ``data['kwargs']['batch_size']`` and
                ``model['input_shape']`` (plus whatever ``get_dataset``/
                ``get_model``/``get_optimizer`` consume).
        """
        self.config_data = config['data']
        self.config_model = config['model']
        self.config_solver = config['solver']

        self.max_steps = self.config_solver['max_steps']
        self.save_steps = self.config_solver['save_steps']
        self.log_steps = self.config_solver['log_steps']
        self.resume_steps = self.config_solver['resume_steps']

        self.dataset = get_dataset(self.config_data)

        # <exp_root>/models for checkpoints, <exp_root>/logs for log files.
        # exist_ok=True avoids the exists-then-create race.
        self.exp_root = self.config_solver['exp_root']
        os.makedirs(self.exp_root, exist_ok=True)
        self.ckpt_root = os.path.join(self.exp_root, 'models')
        os.makedirs(self.ckpt_root, exist_ok=True)
        self.log_root = os.path.join(self.exp_root, 'logs')
        os.makedirs(self.log_root, exist_ok=True)

        self.model = get_model(self.config_model)
        self.loss = get_loss(self.config_solver['loss'])
        self.optimizer = get_optimizer(self.config_solver)

        # Running metrics; reset after every log point in run().
        self.pr_meter = BinaryPRMeter()
        self.dice_meter = DiceMeter()
        self.loss_meter = tf.metrics.Mean()

        log_time = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.logger = Logger(os.path.join(self.log_root, f'{log_time}.log'))

        # Build with a concrete (batch_size, *input_shape) so summary() can run.
        self.input_shape = (self.config_data['kwargs']['batch_size'],
                            *self.config_model['input_shape'])
        self.model.build(self.input_shape)
        self.model.summary(print_fn=self.logger.info)

    @tf.function
    def train_step(self, image_tensor, label_tensor):
        """Apply one gradient update.

        Returns:
            ``(loss, pred)``, where ``pred`` is softmax channel 0 of the rank-5
            logits, kept as a trailing size-1 axis.
        """
        with tf.GradientTape() as tape:
            logit_tensor = self.model(image_tensor, training=True)
            train_loss_value = self.loss(label_tensor, logit_tensor)
        grads = tape.gradient(train_loss_value, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        pred_tensor = tf.nn.softmax(logit_tensor)[:, :, :, :, 0:1]
        return train_loss_value, pred_tensor

    def save_model(self, step):
        """Write model/optimizer state to ``<exp_root>/models/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.write(ckpt_path)

    def load_model(self, step):
        """Restore model/optimizer state from ``<exp_root>/models/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.restore(ckpt_path).expect_partial()

    def run(self):
        """Train until ``max_steps`` or until the loader is exhausted.

        Resumes from checkpoint ``resume_steps`` when it is > 0. Every
        ``log_steps`` steps, logs the averaged metrics; every ``save_steps``
        steps, writes a checkpoint. A non-positive interval disables the
        corresponding action.
        """
        self.logger.info('Training started.')
        start_tick = time.time()
        train_loader = self.dataset.get_loader(training=True)

        if self.resume_steps > 0:
            step = self.resume_steps
            self.load_model(step)
            self.logger.info(f'[Step {step}] Checkpoint resumed.')
        else:
            step = 0
            self.logger.info('Train from scratch.')
        step += 1  # Number of the first step run in this session.

        if step > self.max_steps:
            # The step budget is already used up (e.g. resuming from the final
            # checkpoint): do not run any step beyond max_steps.
            self.logger.info('Training finished.')
            return

        for image_tensor, label_tensor in train_loader:
            image_tensor = (image_tensor - tf.reduce_mean(image_tensor)) / tf.math.reduce_std(image_tensor)
            train_loss_value, pred_tensor = self.train_step(image_tensor, label_tensor)
            self.dice_meter.update_state(label_tensor, pred_tensor)
            self.pr_meter.update_state(label_tensor, pred_tensor)
            self.loss_meter.update_state(train_loss_value)

            # Guarded like save_steps so log_steps == 0 does not divide by zero.
            if self.log_steps > 0 and step % self.log_steps == 0:
                train_p, train_r = self.pr_meter.result()
                train_dice = self.dice_meter.result()
                train_loss_value = self.loss_meter.result()
                self.pr_meter.reset_states()
                self.dice_meter.reset_states()
                self.loss_meter.reset_states()
                elapsed_time = time.time() - start_tick
                # Whole seconds -> 'H:MM:SS'. Unlike slicing the string,
                # this is safe when the elapsed time has no microseconds.
                et = str(datetime.timedelta(seconds=int(elapsed_time)))
                # NOTE(review): _decayed_lr is a private optimizer helper, not
                # public API; confirm it exists for the TF/Keras version in use.
                cur_lr = self.optimizer._decayed_lr(tf.float32)
                self.logger.info(f'[{et} Step {step:d} / {self.max_steps}] Training loss = {train_loss_value:.6f}, lr = {cur_lr:.6f}, Precision = {train_p:.4f}, Recall = {train_r:.4f}, Dice = {train_dice:.4f}.')

            if self.save_steps > 0 and step % self.save_steps == 0:
                self.save_model(step)
                self.logger.info(f'[Step {step}] checkpoint saved.')

            if step >= self.max_steps:
                break
            step += 1

        self.logger.info('Training finished.')
class SolverTest(object):
    """Evaluate a 3D model on whole test volumes.

    Inference uses either a depth-wise sliding window (``hybrid``) or cube
    decomposition.
    """

    def __init__(self, config):
        """Set up data, model, output directories and logging for testing.

        Args:
            config: mapping with ``'data'``, ``'model'`` and ``'solver'``
                sections. Reads ``solver['test_steps']``,
                ``solver['export_results']``, ``solver['hybrid']``,
                ``solver['exp_root']``, ``data['kwargs']['batch_size']`` and
                ``model['input_shape']`` (plus whatever ``get_dataset``/
                ``get_model``/``get_optimizer`` consume).
        """
        self.config_data = config['data']
        self.config_model = config['model']
        self.config_solver = config['solver']

        self.test_steps = self.config_solver['test_steps']
        self.export = self.config_solver['export_results']
        self.hybrid = self.config_solver['hybrid']

        self.dataset = get_dataset(self.config_data)

        # <exp_root>/{models,logs,results}; exist_ok=True avoids the
        # exists-then-create race.
        self.exp_root = self.config_solver['exp_root']
        os.makedirs(self.exp_root, exist_ok=True)
        self.ckpt_root = os.path.join(self.exp_root, 'models')
        os.makedirs(self.ckpt_root, exist_ok=True)
        self.log_root = os.path.join(self.exp_root, 'logs')
        os.makedirs(self.log_root, exist_ok=True)
        self.result_root = os.path.join(self.exp_root, 'results')
        os.makedirs(self.result_root, exist_ok=True)

        self.model = get_model(self.config_model)
        self.optimizer = get_optimizer(self.config_solver)

        self.pr_meter = BinaryPRMeter()
        self.dice_meter = DiceMeter()
        self.loss_meter = tf.metrics.Mean()

        log_time = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.logger = Logger(os.path.join(self.log_root, f'{log_time}.log'))

        self.batch_size = self.config_data['kwargs']['batch_size']
        self.input_shape = (self.batch_size, *self.config_model['input_shape'])
        self.model.build(self.input_shape)
        self.model.summary(print_fn=self.logger.info)

    def chop_forward(self, input_tensor):
        """Sliding-window inference along axis 3 with window depth ``D``.

        ``D`` is ``self.input_shape[3]``. The volume is padded symmetrically by
        ``D // 2`` slices on axis 3. The model is run on every full
        ``D``-deep window at stride 1, overlapping logits are averaged per
        slice, and the padding is cropped off again.

        Args:
            input_tensor: rank-5 tensor windowed along axis 3. A batch of one
                is presumably expected -- TODO confirm against the loader.

        Returns:
            NumPy float64 array of averaged logits. Its axis-3 length equals
            ``input_tensor.shape[3]``; the other dims follow the model output.
        """
        depth = self.input_shape[3]
        pad = depth // 2
        padded_tensor = tf.pad(
            input_tensor, ((0, 0), (0, 0), (0, 0), (pad, pad), (0, 0)), 'symmetric')
        pz_layers = padded_tensor.shape[3]

        num_logits = None
        # Number of windows covering each padded slice (averaging denominator).
        den_logits = np.zeros(pz_layers)
        # Only start windows that fit entirely. A later start would feed the
        # model a truncated (< depth) input: that can fail for models built for
        # a fixed depth, and otherwise blends reduced-context predictions into
        # the last slices. With D // 2 padding, every kept slice is still
        # covered by at least one full window.
        for d_idx in range(pz_layers - depth + 1):
            window_logits = self.model(
                padded_tensor[:, :, :, d_idx:d_idx + depth, :], training=False)
            if num_logits is None:
                # Size the accumulator from the model output instead of
                # hard-coding the class count and in-plane size.
                out_shape = list(window_logits.shape)
                out_shape[3] = pz_layers
                num_logits = np.zeros(out_shape)
            num_logits[:, :, :, d_idx:d_idx + depth, :] += window_logits
            den_logits[d_idx:d_idx + depth] += 1.0

        # Explicit stop index: `pad:-pad` would be empty when pad == 0.
        keep = slice(pad, pz_layers - pad)
        logit_tensor = num_logits[:, :, :, keep, :] / den_logits[keep][:, None]
        return logit_tensor

    def save_model(self, step):
        """Write model/optimizer state to ``<exp_root>/models/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.write(ckpt_path)

    def load_model(self, step):
        """Restore model/optimizer state from ``<exp_root>/models/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.restore(ckpt_path).expect_partial()

    def run(self):
        """Restore checkpoint ``test_steps`` and evaluate every test volume.

        Each volume is standardised with its own mean/std. With ``hybrid``
        set, it goes through ``chop_forward``. Otherwise it is split with
        ``decompose_vol2cube``, each piece is run through the model, and the
        outputs are merged with ``compose_prob_cube2vol``.
        """
        self.logger.info('Testing started.')
        test_loader = self.dataset.get_loader(training=False)
        step = self.test_steps
        self.load_model(step)
        self.logger.info(f'[Step {step:d}] Model Loaded.')

        for test_idx, (image_tensor, label_tensor, affine_tensor) in enumerate(test_loader):
            input_tensor = (image_tensor - tf.reduce_mean(image_tensor)) / tf.math.reduce_std(image_tensor)
            if self.hybrid:
                logit_tensor = self.chop_forward(input_tensor)
            else:
                # NOTE(review): 64, 1, 4 and 2 are positional arguments whose
                # meaning is defined by decompose_vol2cube /
                # compose_prob_cube2vol; confirm before changing the model size.
                inputs_list = decompose_vol2cube(
                    tf.squeeze(input_tensor), self.batch_size, 64, 1, 4)
                logits_list = [self.model(x, training=False) for x in inputs_list]
                logit_tensor = compose_prob_cube2vol(
                    logits_list, image_tensor.shape[1:4], self.batch_size, 64, 4, 2)
            # Channel 0 of the softmax is taken as the positive class.
            pred_tensor = tf.nn.softmax(logit_tensor)[..., 0:1]

            self.pr_meter.update_state(label_tensor, pred_tensor)
            self.dice_meter.update_state(label_tensor, pred_tensor)
            test_p, test_r = precision_recall(label_tensor, pred_tensor)
            test_dice = dice_coef(label_tensor, pred_tensor)
            self.logger.info(
                f'[Test {test_idx + 1}] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
            )

            if self.export:
                # NOTE(review): file names use the 0-based test_idx, while the
                # log line above is 1-based. Kept as-is so existing result
                # paths do not change.
                save_image_nii(
                    image_tensor, affine_tensor,
                    os.path.join(self.result_root, f'{step}_{test_idx}_img.nii'))
                save_pred_nii(
                    pred_tensor, affine_tensor,
                    os.path.join(self.result_root, f'{step}_{test_idx}_pred.nii'))
                save_label_nii(
                    label_tensor, affine_tensor,
                    os.path.join(self.result_root, f'{step}_{test_idx}_gt.nii'))

        test_p, test_r = self.pr_meter.result()
        test_dice = self.dice_meter.result()
        self.logger.info(
            f'[Total Average] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
        )