Example #1
0
    def train(self):
        """Run the training loop for ``self.conf.epochs`` epochs.

        Each epoch trains ``self.batches`` batches via ``self.train_batch``,
        runs ``self.validate`` on the same loss dictionary, then:

        * logs the epoch mean of every loss returned by ``get_loss_names``;
        * writes those means to ``<conf.folder>/training.csv`` (CSVLogger)
          and to ``SaveLoss``;
        * plots example images and saves the models;
        * stops early when ``self.stop_criterion`` fires on an
          ``EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)``
          monitor attached to ``self.model.Segmentor``.

        The CSV logger opened by ``on_train_begin`` is always closed through
        ``on_train_end`` in a ``finally`` clause, so the file handle is released
        whether training finishes, stops early, or raises.
        """
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        # Per-loss history of epoch means across the whole run.
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.batches * self.conf.batch_size)
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                for self.batch in range(self.batches):
                    self.train_batch(epoch_loss)
                    progress_bar.update((self.batch + 1) * self.conf.batch_size)

                self.validate(epoch_loss)

                # Epoch mean of every loss, computed once and reused for the
                # history, the log line and the callbacks.
                logs = {n: np.mean(epoch_loss[n]) for n in loss_names}
                for n in loss_names:
                    total_loss[n].append(logs[n])
                log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                         ((self.epoch, self.conf.epochs) + tuple(logs[l] for l in loss_names)))

                # NOTE(review): a model with stop_training=False is attached
                # presumably because CSVLogger.on_epoch_end consults
                # model.stop_training — confirm against the callback version.
                cl.model = self.model.D_Mask
                cl.model.stop_training = False
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # Plot some example images
                self.img_callback.on_epoch_end(self.epoch)

                self.model.save_models()

                if self.stop_criterion(es, logs):
                    log.info('Finished training from early stopping criterion')
                    break
        finally:
            # Close training.csv, which on_train_begin opened.
            cl.on_train_end()
Example #2
0
    def train(self):
        """Train the model with per-epoch SWA updates, then install the SWA models.

        Each epoch trains ``self.batches`` batches via ``self.train_batch``,
        updates the SWA models (``set_swa_model_weights`` followed by each
        SWA callback's ``on_epoch_end``), runs ``self.validate``, and then:

        * logs the epoch mean of every loss returned by ``get_loss_names``;
        * writes those means to ``<conf.folder>/training.csv`` (CSVLogger)
          and to ``SaveLoss``;
        * plots example images and saves the models.

        Training ends after ``conf.epochs`` epochs or when ``self.stop_criterion``
        fires on the ``val_loss_mod2_fused`` early-stopping monitor. In either
        case, provided at least one epoch ran, the SWA callbacks are finalised,
        every sub-model of ``self.model`` is replaced by its SWA counterpart and
        the models are saved again. The CSV logger is always closed, even if
        training raises.
        """
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        # Per-loss history of epoch means across the whole run.
        total_loss = {n: [] for n in loss_names}

        # Latest epoch logs; stays None when no epoch runs (conf.epochs == 0).
        logs = None
        progress_bar = Progbar(target=self.batches * self.conf.batch_size)
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                for self.batch in range(self.batches):
                    self.train_batch(epoch_loss)
                    progress_bar.update((self.batch + 1) * self.conf.batch_size)

                self.set_swa_model_weights()
                for swa_m in self.get_swa_models():
                    swa_m.on_epoch_end(self.epoch)

                self.validate(epoch_loss)

                # Epoch mean of every loss, computed once and reused.
                logs = {n: np.mean(epoch_loss[n]) for n in loss_names}
                for n in loss_names:
                    total_loss[n].append(logs[n])
                log.info(
                    str('Epoch %d/%d: ' +
                        ', '.join([l + ' Loss = %.5f' for l in loss_names])) %
                    ((self.epoch, self.conf.epochs) + tuple(logs[l]
                                                            for l in loss_names)))

                # NOTE(review): a model with stop_training=False is attached
                # presumably because CSVLogger.on_epoch_end consults
                # model.stop_training — confirm against the callback version.
                cl.model = self.model.D_Mask
                cl.model.stop_training = False
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # print images
                self.img_callback.on_epoch_end(self.epoch)

                self.save_models()

                if self.stop_criterion(es, logs):
                    log.info('Finished training from early stopping criterion')
                    break

            # Finalise SWA whether training stopped early or ran every epoch.
            if logs is not None:
                es.on_train_end(logs)
                for swa_m in self.get_swa_models():
                    swa_m.on_train_end()

                # Set final model parameters based on SWA
                self.model.D_Mask = self.swa_D_Mask.model
                self.model.D_Image1 = self.swa_D_Image1.model
                self.model.D_Image2 = self.swa_D_Image2.model
                self.model.Encoders_Anatomy[0] = self.swa_Enc_Anatomy1.model
                self.model.Encoders_Anatomy[1] = self.swa_Enc_Anatomy2.model
                self.model.Enc_Modality = self.swa_Enc_Modality.model
                self.model.Anatomy_Fuser = self.swa_Anatomy_Fuser.model
                self.model.Segmentor = self.swa_Segmentor.model
                self.model.Decoder = self.swa_Decoder.model
                self.model.Balancer = self.swa_Balancer.model

                self.save_models()
        finally:
            # Close training.csv, which on_train_begin opened.
            cl.on_train_end(logs)
Example #3
0
# Optional learning-rate reduction on plateau, monitoring 'loss_va' and
# configured from the red_lr_* entries of `conf`.
if conf['red_lr_plateau']:
    redlr = ReduceLROnPlateau(monitor='loss_va',
                              factor=conf['red_lr_factor'],
                              patience=conf['red_lr_patience'],
                              verbose=True,
                              mode='min',
                              min_delta=conf['red_lr_eps'],
                              min_lr=conf['red_lr_min_lr'])
    redlr.model = snmt.model
# Optional early stopping, also monitoring 'loss_va'.
if conf['early_stopping']:
    earlstop = EarlyStopping(monitor='loss_va',
                             min_delta=conf['early_stopping_eps'],
                             patience=conf['early_stopping_patience'],
                             verbose=True,
                             mode='min')
    earlstop.model = snmt.model
    earlstop.on_train_begin()

# Prepare savers.
saver = callbacks.TrainStateSaver(path_trrun,
                                  model=snmt.model,
                                  optimizer=optimizer,
                                  verbose=True)

# Create data loaders.
# Each optional stream's path is looked up only when that stream is enabled,
# so a disabled stream needs no path entry in `conf`; disabled -> None.
prn = conf['path_normals'] if conf['normals_stream'] else None
prd = conf['path_dmaps'] if conf['depth_stream'] else None
prm = conf['path_meshes'] if conf['mesh_stream'] else None
tf_dm = TfReshape(tuple(conf['input_shape'][:2]) + (1, ))
ds_tr = DatasetImgNDM(conf['path_imgs'],
Example #4
0
    def train(self):
        """Train the model and periodically evaluate it on test/validation data.

        Behaviour per epoch:

        * Trains ``conf.batches`` batches, passing an exponentially decaying
          learning-rate schedule to ``self.train_batch``.
        * Runs ``self.validate`` and saves the models.
        * Every ``eval_train_interval`` epochs, and on the last epoch, it also:
          - plots example images;
          - runs ``test_modality`` on the 'test' and 'validation' splits,
            writing results under ``<conf.folder>/test_during_train``;
          - saves a 'BestModel' whenever the test score improves;
          - appends both scores to the performance CSVs.
        * Records the latest test score under ``'Test_Performance_Dice'``.
        * Logs epoch-mean losses: the current value for the first five epochs,
          then the last five values.
        * Feeds those means to ``training.csv`` (CSVLogger) and ``SaveLoss``.

        The performance CSVs from a previous run are deleted at start-up.

        Early stopping (``self.stop_criterion`` on ``'Validate_Dice'``) is
        evaluated every epoch but may end training only after half of
        ``conf.epochs`` has elapsed. The CSV logger is always closed, even if
        training raises.
        """
        def _learning_rate_schedule(epoch):
            # lr * exp(-coef * (epoch + 1)): decays from the first epoch onward.
            return self.conf.lr * math.exp(self.lr_schedule_coef * (-epoch - 1))

        test_csv = os.path.join(self.conf.folder, 'test-performance.csv')
        validation_csv = os.path.join(self.conf.folder, 'validation-performance.csv')
        # Start every run with fresh performance CSVs (write_csv appends to them).
        for stale_csv in (test_csv, validation_csv):
            try:
                os.remove(stale_csv)
            except FileNotFoundError:
                pass

        log.info('Training Model')
        # Evaluate on test/validation roughly 50 times over the run (at least every epoch).
        self.eval_train_interval = int(max(1, self.conf.epochs / 50))

        self.init_train_data()
        lr_callback = LearningRateScheduler(_learning_rate_schedule)

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('Validate_Dice', self.conf.min_delta, self.conf.patience)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        loss_names.sort()
        # Per-loss history of epoch means across the whole run.
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches)
        during_train_dir = os.path.join(self.conf.folder, 'test_during_train')

        best_performance = 0.
        test_performance = 0.
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                for self.batch in range(self.conf.batches):
                    self.train_batch(epoch_loss, lr_callback)
                    progress_bar.update(self.batch + 1)

                self.validate(epoch_loss)

                # NOTE(review): a model with stop_training=False is attached
                # presumably because CSVLogger.on_epoch_end consults
                # model.stop_training — confirm against the callback version.
                cl.model = self.model.D_Reconstruction
                cl.model.stop_training = False

                self.model.save_models()

                is_eval_epoch = (self.epoch % self.eval_train_interval == 0
                                 or self.epoch == self.conf.epochs - 1)
                if is_eval_epoch:
                    # Plot some example images
                    self.img_clb.on_epoch_end(self.epoch)

                    folder = os.path.join(during_train_dir,
                                          'test_results_%s_epoch%d'
                                          % (self.conf.test_dataset, self.epoch))
                    os.makedirs(folder, exist_ok=True)
                    test_performance = self.test_modality(folder, self.conf.modality, 'test', False)
                    if test_performance > best_performance:
                        best_performance = test_performance
                        self.model.save_models('BestModel')
                        log.info("BestModel@Epoch%d" % self.epoch)

                    folder = os.path.join(during_train_dir,
                                          'validation_results_%s_epoch%d'
                                          % (self.conf.test_dataset, self.epoch))
                    os.makedirs(folder, exist_ok=True)
                    validation_performance = self.test_modality(folder, self.conf.modality, 'validation', False)
                    # NOTE(review): check_batch_iters is not defined in this
                    # method; presumably a module-level constant — confirm.
                    if self.conf.batches > check_batch_iters:
                        self.write_csv(test_csv, self.epoch, self.batch, test_performance)
                        self.write_csv(validation_csv, self.epoch, self.batch, validation_performance)
                # On non-evaluation epochs this repeats the most recent test score.
                epoch_loss['Test_Performance_Dice'].append(test_performance)

                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))

                # Show the current value for the first five epochs, afterwards
                # the trend over the last five epochs ("a->b->c->d->e").
                history = 1 if self.epoch < 5 else 5
                loss_info = ''.join(
                    l + ' Loss = ' + '->'.join('%.3f' % v for v in total_loss[l][-history:]) + '\n'
                    for l in loss_names)
                log.info('Epoch %d/%d:\n' % (self.epoch, self.conf.epochs) + loss_info)
                log.info("BestTest:%f" % best_performance)
                log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))
                logs = {l: total_loss[l][-1] for l in loss_names}
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # stop_criterion is evaluated every epoch (it updates `es`), but
                # may only end training in the second half of the run.
                if self.stop_criterion(es, logs) and self.epoch > self.conf.epochs / 2:
                    log.info('Finished training from early stopping criterion')
                    self.img_clb.on_epoch_end(self.epoch)
                    break
        finally:
            # Close training.csv, which on_train_begin opened.
            cl.on_train_end()
    def train(self):
        """Run the adversarial training loop for ``self.conf.epochs`` epochs.

        Each epoch trains ``conf.batches`` batches via ``self.train_batch``.
        It then sanity-checks that the generator trainer's weights changed;
        the discriminator trainer's weights must also have changed, unless
        there is no unlabelled generator or ``D_trainer`` is not trainable.
        After that it:

        * runs ``self.validate`` and logs the epoch mean of every loss;
        * writes those means to ``<conf.folder>/training.csv`` (CSVLogger)
          and to ``SaveLoss``;
        * plots example images and saves the models;
        * stops early when ``self.stop_criterion`` fires on an
          ``EarlyStopping('val_loss', min_delta=0.01, patience=100)`` monitor.

        The CSV logger is always closed, even if training raises.
        """
        def _mean_weight(trainer):
            # Scalar summary of a model's weights, compared before/after an
            # epoch to detect whether training actually updated them.
            return np.mean([np.mean(w) for w in trainer.get_weights()])

        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('val_loss', min_delta=0.01, patience=100)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        # Per-loss history of epoch means across the whole run.
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)

        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                D_initial_weights = _mean_weight(self.model.D_trainer)
                G_initial_weights = _mean_weight(self.model.G_trainer)
                for self.batch in range(self.conf.batches):
                    self.train_batch(epoch_loss)

                    progress_bar.update((self.batch + 1) * self.conf.batch_size)

                G_final_weights = _mean_weight(self.model.G_trainer)
                D_final_weights = _mean_weight(self.model.D_trainer)

                # Developer sanity checks (disabled under python -O): training
                # must have moved G, and D too unless D is frozen or there is
                # no unlabelled data generator.
                assert self.gen_unlabelled is None or not self.model.D_trainer.trainable \
                       or D_initial_weights != D_final_weights
                assert G_initial_weights != G_final_weights

                self.validate(epoch_loss)

                # Epoch mean of every loss, computed once and reused.
                logs = {n: np.mean(epoch_loss[n]) for n in loss_names}
                for n in loss_names:
                    total_loss[n].append(logs[n])
                log.info(
                    str('Epoch %d/%d: ' +
                        ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                    ((self.epoch, self.conf.epochs) + tuple(logs[l]
                                                            for l in loss_names)))

                # NOTE(review): a model with stop_training=False is attached
                # presumably because CSVLogger.on_epoch_end consults
                # model.stop_training — confirm against the callback version.
                cl.model = self.model.D_Mask
                cl.model.stop_training = False
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # Plot some example images
                self.img_clb.on_epoch_end(self.epoch)

                self.model.save_models()

                if self.stop_criterion(es, logs):
                    log.info('Finished training from early stopping criterion')
                    break
        finally:
            # Close training.csv, which on_train_begin opened.
            cl.on_train_end()