예제 #1
0
class SolverTrain(object):
    """Training driver built from a nested config dict.

    ``config`` must contain ``'data'``, ``'model'`` and ``'solver'`` sections.
    The solver section is read for ``max_steps``, ``save_steps``,
    ``log_steps``, ``resume_steps``, ``exp_root`` and ``loss``, and is also
    passed whole to ``get_optimizer``. Checkpoints are written under
    ``<exp_root>/models`` and log files under ``<exp_root>/logs``.
    """

    def __init__(self, config):
        self.config_data = config['data']
        self.config_model = config['model']
        self.config_solver = config['solver']

        self.max_steps = self.config_solver['max_steps']
        self.save_steps = self.config_solver['save_steps']
        self.log_steps = self.config_solver['log_steps']
        self.resume_steps = self.config_solver['resume_steps']

        self.dataset = get_dataset(self.config_data)

        # Experiment directory layout. exist_ok avoids the race between an
        # existence check and the actual creation.
        self.exp_root = self.config_solver['exp_root']
        self.ckpt_root = os.path.join(self.exp_root, 'models')
        self.log_root = os.path.join(self.exp_root, 'logs')
        for directory in (self.exp_root, self.ckpt_root, self.log_root):
            os.makedirs(directory, exist_ok=True)

        self.model = get_model(self.config_model)
        self.loss = get_loss(self.config_solver['loss'])
        self.optimizer = get_optimizer(self.config_solver)

        # Running metrics: updated every step, reported and reset every
        # ``log_steps`` steps (see _log_and_reset_meters).
        self.pr_meter = BinaryPRMeter()
        self.dice_meter = DiceMeter()
        self.loss_meter = tf.metrics.Mean()

        # NOTE(review): ':' is not a legal filename character on Windows.
        log_time = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.logger = Logger(os.path.join(self.log_root, f'{log_time}.log'))

        # Build the model from the configured batch size + per-sample input
        # shape so that its summary can be written to the log.
        self.input_shape = (self.config_data['kwargs']['batch_size'],
                            *self.config_model['input_shape'])
        self.model.build(self.input_shape)
        self.model.summary(print_fn=self.logger.info)

    @tf.function
    def train_step(self, image_tensor, label_tensor):
        """Run one optimisation step.

        Returns:
            ``(loss, pred)`` where ``pred`` is the softmax probability of
            output channel 0, kept as a size-1 last axis. The 5-index slice
            below fixes the logits to rank 5; the axis layout is presumably
            (N, H, W, D, C) — TODO confirm.
        """
        with tf.GradientTape() as tape:
            logit_tensor = self.model(image_tensor, training=True)
            train_loss_value = self.loss(label_tensor, logit_tensor)

        grads = tape.gradient(train_loss_value, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        pred_tensor = tf.nn.softmax(logit_tensor)[:, :, :, :, 0:1]
        return train_loss_value, pred_tensor

    def save_model(self, step):
        """Write model + optimizer state to ``<ckpt_root>/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.write(ckpt_path)

    def load_model(self, step):
        """Restore model + optimizer state from ``<ckpt_root>/step_<step>``.

        ``expect_partial()`` silences warnings about checkpointed values that
        are not consumed by this restore.
        """
        checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.restore(ckpt_path).expect_partial()

    def run(self):
        """Train until ``max_steps`` is reached or the loader is exhausted.

        If ``resume_steps > 0`` the checkpoint for that step is restored first
        and training continues from the following step.
        """
        self.logger.info('Training started.')
        start_tick = time.time()
        train_loader = self.dataset.get_loader(training=True)

        if self.resume_steps > 0:
            step = self.resume_steps
            self.load_model(step)
            self.logger.info(f'[Step {step}] Checkpoint resumed.')
        else:
            step = 0
            self.logger.info('Train from scratch.')

        # The first executed step is one past the resumed (or zero) step.
        step += 1

        for image_tensor, label_tensor in train_loader:
            # Z-score the whole batch with a single mean/std.
            image_tensor = (image_tensor - tf.reduce_mean(image_tensor)) / tf.math.reduce_std(image_tensor)
            train_loss_value, pred_tensor = self.train_step(image_tensor, label_tensor)
            self.dice_meter.update_state(label_tensor, pred_tensor)
            self.pr_meter.update_state(label_tensor, pred_tensor)
            self.loss_meter.update_state(train_loss_value)

            # Guarded like save_steps so a non-positive value disables logging
            # instead of raising ZeroDivisionError.
            if self.log_steps > 0 and step % self.log_steps == 0:
                self._log_and_reset_meters(step, start_tick)

            if self.save_steps > 0 and step % self.save_steps == 0:
                self.save_model(step)
                self.logger.info(f'[Step {step}] checkpoint saved.')

            if step >= self.max_steps:
                break

            step += 1

        self.logger.info('Training finished.')

    def _log_and_reset_meters(self, step, start_tick):
        """Log the metrics accumulated since the last report, then reset them."""
        train_p, train_r = self.pr_meter.result()
        train_dice = self.dice_meter.result()
        mean_loss = self.loss_meter.result()
        self.pr_meter.reset_states()
        self.dice_meter.reset_states()
        self.loss_meter.reset_states()
        # Whole seconds. Slicing str(timedelta)[:-7] breaks whenever the
        # microsecond part is zero, because str() then omits '.ffffff'.
        et = str(datetime.timedelta(seconds=int(time.time() - start_tick)))
        # NOTE(review): _decayed_lr is a private optimizer method; confirm it
        # exists for the optimizer class/TF version in use.
        cur_lr = self.optimizer._decayed_lr(tf.float32)

        self.logger.info(f'[{et} Step {step:d} / {self.max_steps}] Training loss = {mean_loss:.6f}, lr = {cur_lr:.6f}, Precision = {train_p:.4f}, Recall = {train_r:.4f}, Dice = {train_dice:.4f}.')
예제 #2
0
class SolverTest_2d(object):
    """Evaluate a slice-wise (2-D) model on volumetric test data.

    ``config`` must contain ``'data'``, ``'model'`` and ``'solver'`` sections.
    The solver section is read for ``test_steps`` (which checkpoint to load),
    ``export_results`` and ``exp_root``. Results (when exported) are written
    under ``<exp_root>/results``.
    """

    def __init__(self, config):
        self.config_data = config['data']
        self.config_model = config['model']
        self.config_solver = config['solver']

        self.test_steps = self.config_solver['test_steps']
        self.export = self.config_solver['export_results']
        self.dataset = get_dataset(self.config_data)

        # Experiment directory layout; exist_ok avoids a check-then-create race.
        self.exp_root = self.config_solver['exp_root']
        self.ckpt_root = os.path.join(self.exp_root, 'models')
        self.log_root = os.path.join(self.exp_root, 'logs')
        self.result_root = os.path.join(self.exp_root, 'results')
        for directory in (self.exp_root, self.ckpt_root, self.log_root,
                          self.result_root):
            os.makedirs(directory, exist_ok=True)

        self.model = get_model(self.config_model)
        # The optimizer is only needed so the checkpoint object matches the
        # one used for saving.
        self.optimizer = get_optimizer(self.config_solver)

        self.pr_meter = BinaryPRMeter()
        self.dice_meter = DiceMeter()

        # NOTE(review): ':' is not a legal filename character on Windows.
        log_time = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.logger = Logger(os.path.join(self.log_root, f'{log_time}.log'))

        self.input_shape = (self.config_data['kwargs']['batch_size'],
                            *self.config_model['input_shape'])
        self.model.build(self.input_shape)
        self.model.summary(print_fn=self.logger.info)

    def chop_forward(self, input_tensor, gt):
        """Run the 2-D model over every slice along axis 3 of one volume.

        Slices are processed in chunks of ``self.input_shape[0]``; each chunk
        is z-scored with its own mean/std before inference.

        Args:
            input_tensor: rank-5 batch; only element 0 is used and axis 3 is
                treated as the slice (depth) axis. Presumably laid out as
                (N, H, W, D, C) — TODO confirm.
            gt: unused; kept for call-site compatibility.

        Returns:
            Logits with a leading batch axis of 1 and the slice axis restored
            to position 3.

        Raises:
            ValueError: if the volume has no slices along axis 3.
        """
        # (H, W, D, C) -> (D, H, W, C): slices become the batch axis.
        slices = tf.transpose(input_tensor[0, ...], (2, 0, 1, 3))
        batch_size = self.input_shape[0]
        z_layers = input_tensor.shape[3]
        if not z_layers:
            raise ValueError('input volume has no slices along axis 3')
        # Integer ceil-division: range() needs a Python int. The previous
        # tf.math.ceil(...) returned a float tensor, which range() rejects.
        nb_batches = (z_layers + batch_size - 1) // batch_size
        logits = []
        for z in range(nb_batches):
            begin = z * batch_size
            end = min(begin + batch_size, z_layers)
            chunk = slices[begin:end, ...]
            chunk = (chunk - tf.reduce_mean(chunk)) / tf.math.reduce_std(chunk)
            logits.append(self.model(chunk, training=False))
        logit_tensor = tf.concat(logits, axis=0)
        # (D, H, W, K) -> (H, W, D, K), then add back the batch axis.
        logit_tensor = tf.transpose(logit_tensor, (1, 2, 0, 3))
        return tf.expand_dims(logit_tensor, 0)

    def load_model(self, step):
        """Restore model + optimizer state from ``<ckpt_root>/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model,
                                         optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.restore(ckpt_path).expect_partial()

    def run(self):
        """Evaluate every test volume, logging per-volume and averaged metrics.

        Per-volume precision/recall/Dice come from ``precision_recall`` and
        ``dice_coef``; the final averages come from the running meters. When
        ``export_results`` is set, image, prediction and label are saved as
        NIfTI files named ``<step>_<index>_{img,pred,gt}.nii`` (index 1-based).
        """
        self.logger.info('Testing started.')
        test_loader = self.dataset.get_loader(training=False)

        step = self.test_steps
        self.load_model(step)
        self.logger.info(f'[Step {step:d}] Model Loaded.')

        for test_idx, (image_tensor, label_tensor,
                       affine_tensor) in enumerate(test_loader):

            logit_tensor = self.chop_forward(image_tensor, label_tensor)
            # Probability of output channel 0, kept as a size-1 last axis.
            pred_tensor = tf.nn.softmax(logit_tensor)[..., 0:1]

            # Debug output to stdout (not the log file).
            print(
                f'positive predictions:{tf.math.count_nonzero(pred_tensor > 0.5).numpy()}.'
            )

            self.pr_meter.update_state(label_tensor, pred_tensor)
            self.dice_meter.update_state(label_tensor, pred_tensor)
            test_p, test_r = precision_recall(label_tensor, pred_tensor)
            test_dice = dice_coef(label_tensor, pred_tensor)

            if self.export:

                save_image_nii(
                    image_tensor, affine_tensor,
                    os.path.join(self.result_root,
                                 f'{step}_{test_idx + 1}_img.nii'))
                save_pred_nii(
                    pred_tensor, affine_tensor,
                    os.path.join(self.result_root,
                                 f'{step}_{test_idx + 1}_pred.nii'))
                save_label_nii(
                    label_tensor, affine_tensor,
                    os.path.join(self.result_root,
                                 f'{step}_{test_idx + 1}_gt.nii'))

            self.logger.info(
                f'[Test {test_idx + 1}] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
            )

        test_p, test_r = self.pr_meter.result()
        test_dice = self.dice_meter.result()
        self.logger.info(
            f'[Total Average] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
        )
예제 #3
0
파일: test.py 프로젝트: Vladimir2506/SegMRI
class SolverTest(object):
    """Evaluate a volumetric segmentation model on test volumes.

    ``config`` must contain ``'data'``, ``'model'`` and ``'solver'`` sections.
    The solver section is read for ``test_steps`` (which checkpoint to load),
    ``export_results``, ``hybrid`` (selects the inference path in ``run``)
    and ``exp_root``.
    """

    def __init__(self, config):
        self.config_data = config['data']
        self.config_model = config['model']
        self.config_solver = config['solver']

        self.test_steps = self.config_solver['test_steps']
        self.export = self.config_solver['export_results']
        self.hybrid = self.config_solver['hybrid']

        self.dataset = get_dataset(self.config_data)

        # Experiment directory layout; exist_ok avoids a check-then-create race.
        self.exp_root = self.config_solver['exp_root']
        self.ckpt_root = os.path.join(self.exp_root, 'models')
        self.log_root = os.path.join(self.exp_root, 'logs')
        self.result_root = os.path.join(self.exp_root, 'results')
        for directory in (self.exp_root, self.ckpt_root, self.log_root,
                          self.result_root):
            os.makedirs(directory, exist_ok=True)

        self.model = get_model(self.config_model)
        # The optimizer is only needed so the checkpoint object matches the
        # one used for saving.
        self.optimizer = get_optimizer(self.config_solver)

        self.pr_meter = BinaryPRMeter()
        self.dice_meter = DiceMeter()
        self.loss_meter = tf.metrics.Mean()

        # NOTE(review): ':' is not a legal filename character on Windows.
        log_time = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        self.logger = Logger(os.path.join(self.log_root, f'{log_time}.log'))
        self.batch_size = self.config_data['kwargs']['batch_size']
        self.input_shape = (self.batch_size, *self.config_model['input_shape'])
        self.model.build(self.input_shape)
        self.model.summary(print_fn=self.logger.info)

    def chop_forward(self, input_tensor):
        """Sliding-window inference along axis 3 (depth) with overlap averaging.

        The volume is padded symmetrically by ``D // 2`` slices on both sides
        of axis 3, where ``D = self.input_shape[3]`` is the model's configured
        window depth. Every full-depth window of the padded volume is passed
        through the model. Per-slice logits are summed and divided by the
        number of windows covering that slice, and the padding is cropped.

        Args:
            input_tensor: rank-5 volume (tensor or array). Axis 3 is windowed;
                the layout is presumably (1, H, W, Z, C) — TODO confirm.

        Returns:
            float64 numpy array with the same axis-3 length as the input; the
            other axes follow the model output, which is assumed to keep the
            window depth on axis 3.

        Raises:
            ValueError: if the input has no slices along axis 3.
        """
        depth = self.input_shape[3]
        pad = depth // 2
        volume = np.asarray(input_tensor)
        z_layers = volume.shape[3]
        if z_layers == 0:
            raise ValueError('input volume has no slices along axis 3')
        # numpy 'symmetric' mirrors including the edge slice, as tf.pad's
        # SYMMETRIC mode does.
        padded = np.pad(volume,
                        ((0, 0), (0, 0), (0, 0), (pad, pad), (0, 0)),
                        mode='symmetric')
        padded_layers = padded.shape[3]

        logit_sum = None
        # How many windows contributed to each padded slice.
        coverage = np.zeros(padded_layers)
        # Only windows that fit completely. A window starting later would be
        # shallower than the depth the model was built for.
        # padded_layers - depth + 1 >= z_layers >= 1, so the loop always runs.
        for start in range(padded_layers - depth + 1):
            stop = start + depth
            window_logits = np.asarray(
                self.model(padded[:, :, :, start:stop, :], training=False))
            if logit_sum is None:
                sum_shape = list(window_logits.shape)
                sum_shape[3] = padded_layers
                logit_sum = np.zeros(sum_shape)
            logit_sum[:, :, :, start:stop, :] += window_logits
            coverage[start:stop] += 1.0

        # Explicit bounds: the old ``pad:-pad`` crop was empty when pad == 0.
        keep = slice(pad, pad + z_layers)
        return logit_sum[:, :, :, keep, :] / coverage[keep].reshape(1, 1, 1, -1, 1)

    def save_model(self, step):
        """Write model + optimizer state to ``<ckpt_root>/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model,
                                         optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.write(ckpt_path)

    def load_model(self, step):
        """Restore model + optimizer state from ``<ckpt_root>/step_<step>``."""
        checkpoint = tf.train.Checkpoint(model=self.model,
                                         optimizer=self.optimizer)
        ckpt_path = os.path.join(self.ckpt_root, f'step_{step:d}')
        checkpoint.restore(ckpt_path).expect_partial()

    def run(self):
        """Evaluate every test volume, logging per-volume and averaged metrics.

        With ``hybrid`` set, inference uses ``chop_forward``; otherwise the
        volume is split with ``decompose_vol2cube`` and reassembled with
        ``compose_prob_cube2vol`` (the literal arguments 64/1/4 and 64/4/2 are
        passed through as-is — their meaning is defined by those helpers).
        When ``export_results`` is set, NIfTI files named
        ``<step>_<index>_{img,pred,gt}.nii`` (index 0-based) are written.
        """
        self.logger.info('Testing started.')
        test_loader = self.dataset.get_loader(training=False)

        step = self.test_steps
        self.load_model(step)
        self.logger.info(f'[Step {step:d}] Model Loaded.')

        for test_idx, (image_tensor, label_tensor,
                       affine_tensor) in enumerate(test_loader):

            # Z-score the whole volume with a single mean/std.
            input_tensor = (image_tensor - tf.reduce_mean(image_tensor)
                            ) / tf.math.reduce_std(image_tensor)
            if self.hybrid:
                logit_tensor = self.chop_forward(input_tensor)
            else:
                inputs_list = decompose_vol2cube(tf.squeeze(input_tensor),
                                                 self.batch_size, 64, 1, 4)
                logits_list = [
                    self.model(x, training=False) for x in inputs_list
                ]
                logit_tensor = compose_prob_cube2vol(logits_list,
                                                     image_tensor.shape[1:4],
                                                     self.batch_size, 64, 4, 2)

            # Probability of output channel 0, kept as a size-1 last axis.
            pred_tensor = tf.nn.softmax(logit_tensor)[..., 0:1]
            self.pr_meter.update_state(label_tensor, pred_tensor)
            self.dice_meter.update_state(label_tensor, pred_tensor)
            test_p, test_r = precision_recall(label_tensor, pred_tensor)
            test_dice = dice_coef(label_tensor, pred_tensor)

            self.logger.info(
                f'[Test {test_idx + 1}] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
            )

            if self.export:
                save_image_nii(
                    image_tensor, affine_tensor,
                    os.path.join(self.result_root,
                                 f'{step}_{test_idx}_img.nii'))
                save_pred_nii(
                    pred_tensor, affine_tensor,
                    os.path.join(self.result_root,
                                 f'{step}_{test_idx}_pred.nii'))
                save_label_nii(
                    label_tensor, affine_tensor,
                    os.path.join(self.result_root,
                                 f'{step}_{test_idx}_gt.nii'))

        test_p, test_r = self.pr_meter.result()
        test_dice = self.dice_meter.result()
        self.logger.info(
            f'[Total Average] Precision = {test_p:.4f}, Recall = {test_r:.4f}, Dice = {test_dice:.4f}.'
        )