Exemplo n.º 1
0
    def train(self):
        """
        Run the main training loop.

        Each epoch trains over ``self.batches`` mini-batches, validates, stores
        the mean of every loss, logs it to the console, CSV file and SaveLoss
        callback, plots example images, saves the models and finally evaluates
        the early-stopping criterion (EarlyStopping on ``val_loss_mod2_fused``
        attached to the Segmentor).
        """
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)

        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.batches * self.conf.batch_size)

        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                for self.batch in range(self.batches):
                    self.train_batch(epoch_loss)
                    progress_bar.update((self.batch + 1) * self.conf.batch_size)

                self.validate(epoch_loss)

                # Reduce the per-batch values of this epoch to one mean per loss.
                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))
                log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                         ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
                logs = {l: total_loss[l][-1] for l in loss_names}

                # The CSV logger is given a model whose stop_training flag is
                # reset before each epoch's row is written.
                cl.model = self.model.D_Mask
                cl.model.stop_training = False
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # Plot some example images
                self.img_callback.on_epoch_end(self.epoch)

                self.model.save_models()

                if self.stop_criterion(es, logs):
                    log.info('Finished training from early stopping criterion')
                    break
        finally:
            # Finalise the CSV logger however training ends (normal completion,
            # early stop or exception); for Keras' CSVLogger this closes the
            # file opened in on_train_begin().
            cl.on_train_end()
Exemplo n.º 2
0
    def train(self):
        """
        Run the training loop with stochastic weight averaging (SWA).

        Each epoch trains over ``self.batches`` mini-batches, updates the SWA
        weights and notifies the SWA callbacks, validates, logs the mean of
        every loss (console, CSV, SaveLoss), plots example images and saves
        the models.

        When the loop ends -- through the early-stopping criterion
        (EarlyStopping on ``val_loss_mod2_fused`` attached to the Segmentor)
        or by reaching ``self.conf.epochs`` -- the callbacks are finalised,
        every sub-model is replaced by its SWA counterpart and the models are
        saved once more.
        """
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)

        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}
        logs = {}

        progress_bar = Progbar(target=self.batches * self.conf.batch_size)

        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                for self.batch in range(self.batches):
                    self.train_batch(epoch_loss)
                    progress_bar.update((self.batch + 1) * self.conf.batch_size)

                self.set_swa_model_weights()
                for swa_m in self.get_swa_models():
                    swa_m.on_epoch_end(self.epoch)

                self.validate(epoch_loss)

                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))
                log.info(
                    str('Epoch %d/%d: ' +
                        ', '.join([l + ' Loss = %.5f' for l in loss_names])) %
                    ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1]
                                                            for l in loss_names)))
                logs = {l: total_loss[l][-1] for l in loss_names}

                cl.model = self.model.D_Mask
                cl.model.stop_training = False
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # print images
                self.img_callback.on_epoch_end(self.epoch)

                self.save_models()

                if self.stop_criterion(es, logs):
                    log.info('Finished training from early stopping criterion')
                    break

            # Finalisation runs however the loop ended (early stop OR last
            # epoch), so the saved models are always the SWA-averaged ones.
            es.on_train_end(logs)
            for swa_m in self.get_swa_models():
                swa_m.on_train_end()

            # Set final model parameters based on SWA
            self.model.D_Mask = self.swa_D_Mask.model
            self.model.D_Image1 = self.swa_D_Image1.model
            self.model.D_Image2 = self.swa_D_Image2.model
            self.model.Encoders_Anatomy[0] = self.swa_Enc_Anatomy1.model
            self.model.Encoders_Anatomy[1] = self.swa_Enc_Anatomy2.model
            self.model.Enc_Modality = self.swa_Enc_Modality.model
            self.model.Anatomy_Fuser = self.swa_Anatomy_Fuser.model
            self.model.Segmentor = self.swa_Segmentor.model
            self.model.Decoder = self.swa_Decoder.model
            self.model.Balancer = self.swa_Balancer.model

            self.save_models()
        finally:
            # Finalise the CSV logger even if training raised; for Keras'
            # CSVLogger this closes the file opened in on_train_begin().
            cl.on_train_end(logs)
    def train(self):
        """
        Train the two-cycle adversarial model with gradient-penalty critics.

        Three augmented streams are drawn each batch: ``p`` images/masks,
        ``h`` images/masks, and an independent stream of ``p`` masks used as
        the "real" input for the mask critic. Per batch the method optionally
        trains ``train_self_rec``, trains ``critic_model`` ``n_critic`` times
        (``conf.ncritic[0]`` for the first 25 epochs, ``conf.ncritic[1]``
        afterwards) and then trains ``gan`` once. Per-epoch mean losses are
        logged (console, CSV, SaveLoss) and the first sample of up to 30
        batches per stream is passed to the image callback.
        """
        self.init_train_data()
        # make generated data
        gen_dict = self.get_datagen_params()
        p_gen = ImageDataGenerator(**gen_dict).flow(x=self.p_images, y=self.p_masks, batch_size=self.conf.batch_size)
        h_gen = ImageDataGenerator(**gen_dict).flow(x=self.h_images, y=self.h_masks, batch_size=self.conf.batch_size)
        random_p_masks = ImageDataGenerator(**gen_dict).flow(x=self.p_masks, batch_size=self.conf.batch_size)

        # initialize training
        batches = int(np.ceil(self.conf.data_len / self.conf.batch_size))
        progress_bar = Progbar(target=batches * self.conf.batch_size)

        sl = SaveLoss(self.conf.folder)
        img_clb = ImageCallback(self.conf, self.model, self.comet_exp)

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        def _interpolate(real, fake):
            # Random per-sample point on the segment between real and fake,
            # used as the gradient-penalty input. The SAME epsilon must weight
            # both terms (the original mask branch mixed two epsilons).
            epsilon = np.random.uniform(0, 1, size=(real.shape[0], 1, 1, 1))
            return epsilon * real + (1 - epsilon) * fake

        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        try:
            # start training
            for epoch in range(self.conf.epochs):
                log.info("Train epoch %d/%d" % (epoch, self.conf.epochs))
                epoch_loss = {n: [] for n in loss_names}
                pool_to_print_p_img, pool_to_print_p_msk = [], []
                pool_to_print_h_img, pool_to_print_h_msk = [], []
                # Critic updates per generator update: ncritic[0] during the
                # first 25 epochs, ncritic[1] afterwards.
                n_critic = self.conf.ncritic[0] if epoch < 25 else self.conf.ncritic[1]

                for batch in range(batches):
                    p_img, p_msk = next(p_gen)
                    h_img, h_msk = next(h_gen)
                    r_p_msk = next(random_p_masks)

                    # Keep the first sample of up to 30 batches for plotting.
                    if len(pool_to_print_p_img) < 30:
                        pool_to_print_p_img.append(p_img[0])
                        pool_to_print_p_msk.append(p_msk[0])
                    if len(pool_to_print_h_img) < 30:
                        pool_to_print_h_img.append(h_img[0])
                        pool_to_print_h_msk.append(h_msk[0])

                    # Adversarial targets: -1 for real, +1 for fake, zeros as
                    # the targets of the gradient-penalty outputs.
                    real_pred = -np.ones((h_img.shape[0], 1))
                    fake_pred = np.ones((h_img.shape[0], 1))
                    dummy = np.zeros((h_img.shape[0], 1))

                    if self.conf.self_rec:
                        h_test_sr = self.model.train_self_rec.fit([h_img, h_msk], [h_img, h_img], epochs=1, verbose=0)
                        epoch_loss["test_self_rec_loss"].append(np.mean(h_test_sr.history["loss"]))
                    else:
                        epoch_loss["test_self_rec_loss"].append(0)

                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    # Get a group of synthetic msks and imgs
                    cy1_pse_h_img = self.model.G_d_to_h.predict(p_img)
                    cy1_seg_d_msk = self.model.S_d_to_msk.predict(p_img)
                    cy2_fake_h_img = self.model.G_h_to_d.predict([h_img, h_msk])

                    h_d = None
                    for _ in range(n_critic):
                        cy1_average = _interpolate(h_img, cy1_pse_h_img)
                        cy1_average_msk = _interpolate(r_p_msk, cy1_seg_d_msk)
                        cy2_average = _interpolate(h_img, cy2_fake_h_img)

                        h_d = self.model.critic_model.fit(
                            [h_img, cy1_pse_h_img, cy1_average,
                             r_p_msk, cy1_seg_d_msk, cy1_average_msk,
                             h_img, cy2_fake_h_img, cy2_average],
                            [real_pred, fake_pred, dummy,
                             real_pred, fake_pred, dummy,
                             real_pred, fake_pred, dummy],
                            epochs=1, verbose=0)

                    # Losses come from the last critic update only; skip them
                    # if no critic update ran (n_critic == 0).
                    if h_d is not None:
                        d_hist = h_d.history
                        epoch_loss['d_dis_pse_image_loss'].append(
                            np.mean([d_hist['dis_cy1_I_pse_h_loss'], d_hist['dis_cy2_I_pse_h_loss']]))
                        epoch_loss['d_dis_r_image_loss'].append(
                            np.mean([d_hist['dis_cy1_I_h_loss'], d_hist['dis_cy2_I_h_loss']]))
                        epoch_loss['d_dis_d_mask_loss'].append(
                            np.mean([d_hist['dis_cy1_M_d_loss'], d_hist['dis_cy1_M_seg_d_loss']]))
                        epoch_loss['d_gp_loss'].append(
                            np.mean([d_hist['gp_cy1_I_h_loss'], d_hist['gp_cy2_I_h_loss'],
                                     d_hist['gp_cy1_M_d_loss']]))

                    # --------------------
                    #  Train Generator
                    # --------------------
                    h_g = self.model.gan.fit([p_img, h_img, h_msk],
                                             [real_pred, real_pred, p_img, real_pred, h_img, h_msk],
                                             epochs=1, verbose=0)
                    g_hist = h_g.history
                    epoch_loss['g_dis_pse_image_loss'].append(
                        np.mean([g_hist['cy1_dis_I_pse_h_loss'], g_hist['cy2_dis_I_pse_d_loss']]))
                    epoch_loss['g_rec_image_loss'].append(
                        np.mean([g_hist['cy2_I_rec_h_loss'], g_hist['cy1_I_rec_d_loss']]))
                    epoch_loss['g_dis_d_mask_loss'].append(np.mean(g_hist['cy1_dis_M_seg_d_loss']))

                    # Plot the progress
                    progress_bar.update((batch + 1) * self.conf.batch_size)

                # One mean per loss per epoch (previously appended twice).
                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))

                log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                         ((epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
                logs = {l: total_loss[l][-1] for l in loss_names}

                cl.model = self.model.D_pse_h
                cl.model.stop_training = False
                cl.on_epoch_end(epoch, logs)
                sl.on_epoch_end(epoch, logs)

                pool_to_print_p_img = np.asarray(pool_to_print_p_img)
                pool_to_print_p_msk = np.asarray(pool_to_print_p_msk)
                pool_to_print_h_img = np.asarray(pool_to_print_h_img)
                pool_to_print_h_msk = np.asarray(pool_to_print_h_msk)
                log.debug("pool_to_print_p_img: %s" % (np.shape(pool_to_print_p_img),))
                img_clb.on_epoch_end(epoch, pool_to_print_p_img, pool_to_print_p_msk,
                                     pool_to_print_h_img, pool_to_print_h_msk)
        finally:
            # Finalise the CSV logger however training ends; for Keras'
            # CSVLogger this closes the file opened in on_train_begin().
            cl.on_train_end()
Exemplo n.º 4
0
    def fit(self):
        """
        Train SDNet.

        Each epoch alternates a generator step and a discriminator step on
        every batch drawn from ``gen_X_L`` (image/mask pairs) and ``gen_X_U``.
        It then:

        * sanity-checks that the generator step leaves the discriminator
          weights unchanged and that a full epoch changes both models;
        * plots example images and validates;
        * logs the mean losses (console, SaveLoss, CSV) and saves the models;
        * checks the early-stopping criterion on ``val_loss``.
        """
        log.info('Training SDNet')

        def _mean_weight(model):
            # Scalar fingerprint of a model's weights (mean of per-tensor
            # means); only used to detect whether training changed them.
            return np.mean([np.mean(w) for w in model.get_weights()])

        # Load data
        self.init_train()

        # Initialise callbacks
        sl = SaveLoss(self.conf.folder)
        si = SDNetCallback(self.conf.folder, self.conf.batch_size, self.sdnet)
        es = EarlyStopping('val_loss', min_delta=0.001, patience=20)
        es.on_train_begin()

        loss_names = [
            'adv_M', 'adv_X', 'rec_X', 'rec_M', 'rec_Z', 'dis_M', 'dis_X',
            'mask', 'image', 'val_loss'
        ]

        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)

        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

                # these are used only for printing images
                real_lb_pool, real_ul_pool = [], []

                epoch_loss = {n: [] for n in loss_names}

                D_initial_weights = _mean_weight(self.sdnet.D_model)
                G_initial_weights = _mean_weight(self.sdnet.G_model)
                for self.batch in range(self.conf.batches):
                    real_lb = next(self.gen_X_L)
                    real_ul = next(self.gen_X_U)

                    # Add image/mask batch to the data pool, one (1, ...) slice
                    # pair per sample.
                    x, m = real_lb
                    real_lb_pool.extend([(x[i:i + 1], m[i:i + 1])
                                         for i in range(x.shape[0])])
                    real_ul_pool.extend(real_ul)

                    # The generator step must not modify the discriminator.
                    D_weights1 = _mean_weight(self.sdnet.D_model)
                    self.train_batch_generator(real_lb, real_ul, epoch_loss)
                    D_weights2 = _mean_weight(self.sdnet.D_model)
                    assert D_weights1 == D_weights2

                    self.train_batch_discriminator(real_lb, real_ul, epoch_loss)

                    progress_bar.update((self.batch + 1) * self.conf.batch_size)

                G_final_weights = _mean_weight(self.sdnet.G_model)
                D_final_weights = _mean_weight(self.sdnet.D_model)

                # Check training is altering weights
                assert D_initial_weights != D_final_weights
                assert G_initial_weights != G_final_weights

                # Plot some example images
                si.on_epoch_end(self.epoch, np.array(real_lb_pool),
                                np.array(real_ul_pool))

                self.validate(epoch_loss)

                # Calculate epoch losses
                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))
                log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                         ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
                logs = {l: total_loss[l][-1] for l in loss_names}
                sl.on_epoch_end(self.epoch, logs)

                # log losses to csv
                cl.model = self.sdnet.D_model
                cl.model.stop_training = False
                cl.on_epoch_end(self.epoch, logs)

                # save models
                self.sdnet.save_models()

                # early stopping
                if self.stop_criterion(es, self.epoch, logs):
                    log.info('Finished training from early stopping criterion')
                    break
        finally:
            # Finalise the CSV logger however training ends; for Keras'
            # CSVLogger this closes the file opened in on_train_begin().
            cl.on_train_end()
Exemplo n.º 5
0
    def train(self):
        """
        Train the age-conditioned generator against a gradient-penalty critic.

        Young (``yng``) and old images are drawn from augmented generators
        whose labels are age and AD vectors concatenated along axis 1. Per
        batch:

        * the generator's output map is added to the young image (optionally
          passed through tanh) to synthesise an old image;
        * ``critic_model`` is trained ``n_critic`` times (``conf.ncritic[0]``
          for the first 25 epochs, ``conf.ncritic[1]`` afterwards);
        * ``gan`` is trained once;
        * ``GAN_zero_reg`` is trained to reproduce the young image when the
          age difference is zero.

        Per-epoch mean losses are logged (console, CSV, SaveLoss) and the last
        batch is passed to the image callback.
        """
        self.init_train_data()
        # make generated data
        gen_dict = self.get_datagen_params()

        # Concatenate age and AD labels so a single ImageDataGenerator stream
        # can carry both.
        yng_labels = np.concatenate([self.train_age_yng, self.train_AD_yng],
                                    axis=1)
        old_labels = np.concatenate([self.train_age_old, self.train_AD_old],
                                    axis=1)

        old_gen = ImageDataGenerator(**gen_dict).flow(
            x=self.train_img_old,
            y=old_labels,
            batch_size=self.conf.batch_size)
        yng_gen = ImageDataGenerator(**gen_dict).flow(
            x=self.train_img_yng,
            y=yng_labels,
            batch_size=self.conf.batch_size)

        # initialize training
        batches = int(np.ceil(self.conf.data_len / self.conf.batch_size))
        progress_bar = Progbar(target=batches * self.conf.batch_size)

        sl = SaveLoss(self.conf.folder)
        img_clb = ImageCallback(self.conf, self.model, self.comet_exp)

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}
        age_dim = self.conf.age_dim

        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        try:
            # start training
            for epoch in range(self.conf.epochs):
                log.info("Train epoch %d/%d" % (epoch, self.conf.epochs))
                epoch_loss = {n: [] for n in loss_names}
                # Critic updates per generator update: ncritic[0] during the
                # first 25 epochs, ncritic[1] afterwards.
                n_critic = self.conf.ncritic[0] if epoch < 25 else self.conf.ncritic[1]

                for batch in range(batches):
                    # Batch labels get their own names so they don't shadow the
                    # full label arrays used to build the generators.
                    old_img, old_lbl = next(old_gen)
                    yng_img, yng_lbl = next(yng_gen)

                    # Split the concatenated labels back into age and AD parts
                    old_age = old_lbl[:, :age_dim, :]
                    old_AD = old_lbl[:, age_dim:, :]
                    yng_age = yng_lbl[:, :age_dim, :]
                    yng_AD = yng_lbl[:, age_dim:, :]

                    # Adversarial targets: -1 for real, +1 for fake, zeros for
                    # the gradient-penalty output, ones as the map target.
                    real_pred = -np.ones((old_img.shape[0], 1))
                    fake_pred = np.ones((old_img.shape[0], 1))
                    dummy = np.zeros((old_img.shape[0], 1))
                    dummy_Img = np.ones(old_img.shape)

                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    age_gap = calculate_age_diff(yng_age, old_age)
                    diff_age = get_age_ord_vector(age_gap,
                                                  expand_dim=1,
                                                  con=self.conf.age_con,
                                                  ord=self.conf.age_ord,
                                                  age_dim=age_dim)
                    # Synthesise old images: generator map added to the young image
                    gen_masks = self.model.generator.predict(
                        [yng_img, diff_age, old_AD])
                    gen_old_img = gen_masks + yng_img
                    if self.conf.use_tanh:
                        gen_old_img = np.tanh(gen_old_img)

                    h_d = None
                    for _ in range(n_critic):
                        epsilon = np.random.uniform(0, 1, size=(old_img.shape[0], 1, 1, 1))
                        interpolation = epsilon * old_img + (1 - epsilon) * gen_old_img
                        h_d = self.model.critic_model.fit(
                            [old_img, old_age, old_AD,
                             gen_old_img, old_age, old_AD,
                             interpolation, old_age, old_AD],
                            [real_pred, fake_pred, dummy],
                            epochs=1,
                            verbose=0)

                    # Losses come from the last critic update only; skip them
                    # if no critic update ran (n_critic == 0).
                    if h_d is not None:
                        d_hist = h_d.history
                        d_loss_real = np.mean(d_hist['d_real_loss'])
                        d_loss_fake = np.mean(d_hist['d_fake_loss'])
                        log.debug('d_real_loss %s d_fake_loss %s' % (d_loss_real, d_loss_fake))
                        epoch_loss['Discriminator_loss'].append(
                            np.mean([d_hist['d_real_loss'], d_hist['d_fake_loss']]))
                        epoch_loss['Discriminator_real_loss'].append(d_loss_real)
                        epoch_loss['Discriminator_fake_loss'].append(d_loss_fake)
                        epoch_loss['Discriminator_gp_loss'].append(np.mean(d_hist['gp_loss']))

                    # --------------------
                    #  Train Generator
                    # --------------------
                    # Train the generator, want discriminator to mistake images as real
                    h = self.model.gan.fit(
                        [yng_img, old_age, diff_age, age_gap, old_AD],
                        [real_pred, dummy_Img],
                        epochs=1,
                        verbose=0)
                    epoch_loss['Generator_fake_loss'].append(
                        np.mean(h.history['discriminator_loss']))
                    epoch_loss['Generator_l1_reg_loss'].append(
                        np.mean(h.history['map_l1_reg_loss']))

                    # -----------------------------------------
                    # Train Generator by self-regularization:
                    # a zero age difference should reproduce the input image.
                    # -----------------------------------------
                    diff_age_zero = np.zeros_like(yng_age)
                    # Must be .fit(): calling the model directly does not train
                    # it and returns no History.
                    h = self.model.GAN_zero_reg.fit([yng_img, diff_age_zero, yng_AD],
                                                    yng_img,
                                                    epochs=1,
                                                    verbose=0)
                    epoch_loss['Generator_zero_gre_loss'].append(np.mean(h.history['self_reg']))

                    # Plot the progress
                    progress_bar.update((batch + 1) * self.conf.batch_size)

                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))

                log.info(
                    str('Epoch %d/%d: ' +
                        ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                    ((epoch, self.conf.epochs) + tuple(total_loss[l][-1]
                                                       for l in loss_names)))
                logs = {l: total_loss[l][-1] for l in loss_names}

                cl.model = self.model.discriminator
                cl.model.stop_training = False
                cl.on_epoch_end(epoch, logs)
                sl.on_epoch_end(epoch, logs)
                # Images of the last batch of the epoch are plotted.
                img_clb.on_epoch_end(epoch, yng_img, yng_age, old_img, old_age)
        finally:
            # Finalise the CSV logger however training ends; for Keras'
            # CSVLogger this closes the file opened in on_train_begin().
            cl.on_train_end()
Exemplo n.º 6
0
    def train(self):
        """
        Run the training loop with an exponentially decaying learning rate.

        Per epoch:

        * train ``self.conf.batches`` batches and validate;
        * save the models;
        * every ``eval_train_interval`` epochs (and on the last epoch), plot
          example images and evaluate on the test and validation splits,
          saving a 'BestModel' copy whenever test performance improves;
        * log the mean losses (console, CSV, SaveLoss).

        The early-stopping criterion (on ``Validate_Dice``) is evaluated every
        epoch but only ends training once more than half of
        ``conf.epochs`` have run.
        """
        def _learning_rate_schedule(epoch):
            # lr * exp(-coef * (epoch + 1))
            return self.conf.lr * math.exp(self.lr_schedule_coef * (-epoch - 1))

        # Start fresh performance CSVs for this run.
        for csv_name in ('test-performance.csv', 'validation-performance.csv'):
            csv_path = os.path.join(self.conf.folder, csv_name)
            if os.path.exists(csv_path):
                os.remove(csv_path)

        log.info('Training Model')
        self.eval_train_interval = int(max(1, self.conf.epochs / 50))

        self.init_train_data()
        lr_callback = LearningRateScheduler(_learning_rate_schedule)

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)

        es = EarlyStopping('Validate_Dice', self.conf.min_delta, self.conf.patience)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        loss_names.sort()
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches)
        test_dir = os.path.join(self.conf.folder, 'test_during_train')

        best_performance = 0.
        test_performance = 0.

        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                for self.batch in range(self.conf.batches):
                    self.train_batch(epoch_loss, lr_callback)
                    progress_bar.update(self.batch + 1)

                self.validate(epoch_loss)

                cl.model = self.model.D_Reconstruction
                cl.model.stop_training = False

                self.model.save_models()

                # Plot some example images and evaluate on test/validation data
                is_last_epoch = self.epoch == self.conf.epochs - 1
                if self.epoch % self.eval_train_interval == 0 or is_last_epoch:
                    self.img_clb.on_epoch_end(self.epoch)

                    folder = os.path.join(test_dir, 'test_results_%s_epoch%d'
                                          % (self.conf.test_dataset, self.epoch))
                    os.makedirs(folder, exist_ok=True)
                    test_performance = self.test_modality(folder, self.conf.modality, 'test', False)
                    if test_performance > best_performance:
                        best_performance = test_performance
                        self.model.save_models('BestModel')
                        log.info("BestModel@Epoch%d" % self.epoch)

                    folder = os.path.join(test_dir, 'validation_results_%s_epoch%d'
                                          % (self.conf.test_dataset, self.epoch))
                    os.makedirs(folder, exist_ok=True)
                    validation_performance = self.test_modality(folder, self.conf.modality, 'validation', False)
                    # NOTE(review): check_batch_iters is not defined in this
                    # method; presumably a module-level constant -- verify.
                    if self.conf.batches > check_batch_iters:
                        self.write_csv(os.path.join(self.conf.folder, 'test-performance.csv'),
                                       self.epoch, self.batch, test_performance)
                        self.write_csv(os.path.join(self.conf.folder, 'validation-performance.csv'),
                                       self.epoch, self.batch, validation_performance)
                # On non-evaluation epochs the most recent test score is repeated.
                epoch_loss['Test_Performance_Dice'].append(test_performance)

                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))

                if self.epoch < 5:
                    log.info(str('Epoch %d/%d:\n' + ''.join([l + ' Loss = %.3f\n' for l in loss_names])) %
                             ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
                else:
                    # Show the trend of each loss over the last five epochs.
                    info_str = str('Epoch %d/%d:\n' % (self.epoch, self.conf.epochs))
                    loss_info = ''.join(
                        l + ' Loss = ' + '->'.join('%.3f' % v for v in total_loss[l][-5:]) + '\n'
                        for l in loss_names)
                    log.info(info_str + loss_info)
                log.info("BestTest:%f" % best_performance)
                log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))
                logs = {l: total_loss[l][-1] for l in loss_names}
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # stop_criterion is evaluated first on every epoch; only the
                # decision to stop is deferred to the second half of training.
                if self.stop_criterion(es, logs) and self.epoch > self.conf.epochs / 2:
                    log.info('Finished training from early stopping criterion')
                    self.img_clb.on_epoch_end(self.epoch)
                    break
        finally:
            # Finalise the CSV logger however training ends; for Keras'
            # CSVLogger this closes the file opened in on_train_begin().
            cl.on_train_end()
    def train(self):
        """
        Run the training loop and sanity-check that weights are updated.

        Each epoch:

        * trains ``self.conf.batches`` batches;
        * asserts that the generator trainer's weights changed and that the
          discriminator trainer's weights changed too, unless there is no
          unlabelled generator or ``D_trainer`` is not trainable;
        * validates and logs the mean losses (console, CSV, SaveLoss);
        * plots example images and saves the models;
        * checks the early-stopping criterion (EarlyStopping on ``val_loss``
          attached to the Segmentor).
        """
        log.info('Training Model')

        def _mean_weight(model):
            # Scalar fingerprint of a model's weights (mean of per-tensor
            # means); only used to detect whether training changed them.
            return np.mean([np.mean(w) for w in model.get_weights()])

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)

        es = EarlyStopping('val_loss', min_delta=0.01, patience=100)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)

        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        try:
            for self.epoch in range(self.conf.epochs):
                log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

                epoch_loss = {n: [] for n in loss_names}

                D_initial_weights = _mean_weight(self.model.D_trainer)
                G_initial_weights = _mean_weight(self.model.G_trainer)
                for self.batch in range(self.conf.batches):
                    self.train_batch(epoch_loss)
                    progress_bar.update((self.batch + 1) * self.conf.batch_size)

                G_final_weights = _mean_weight(self.model.G_trainer)
                D_final_weights = _mean_weight(self.model.D_trainer)

                # The discriminator is only expected to change when there is
                # unlabelled data and D_trainer is trainable.
                assert self.gen_unlabelled is None or not self.model.D_trainer.trainable \
                    or D_initial_weights != D_final_weights
                assert G_initial_weights != G_final_weights

                self.validate(epoch_loss)

                for n in loss_names:
                    total_loss[n].append(np.mean(epoch_loss[n]))
                log.info(
                    str('Epoch %d/%d: ' +
                        ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                    ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1]
                                                            for l in loss_names)))
                logs = {l: total_loss[l][-1] for l in loss_names}

                cl.model = self.model.D_Mask
                cl.model.stop_training = False
                cl.on_epoch_end(self.epoch, logs)
                sl.on_epoch_end(self.epoch, logs)

                # Plot some example images
                self.img_clb.on_epoch_end(self.epoch)

                self.model.save_models()

                if self.stop_criterion(es, logs):
                    log.info('Finished training from early stopping criterion')
                    break
        finally:
            # Finalise the CSV logger however training ends; for Keras'
            # CSVLogger this closes the file opened in on_train_begin().
            cl.on_train_end()