Example #1
0
def train_gan(dataf, iters=5000):
    """Train the GAN on the ``faces`` dataset stored in an HDF5 file.

    :param dataf: path to an HDF5 file containing a ``faces`` dataset.
    :param iters: number of iterations handed to ``run_batches`` as
        ``range(iters)``; defaults to 5000, the value previously hard-coded.
    """
    # Build the generator, discriminator and combined GAN models.
    gen, disc, gan = build_networks()
    logger = CSVLogger('loss.csv')
    logger.on_train_begin()  # initialize csv file

    try:
        # Run training for `iters` iterations (the old comment claimed
        # "500 epochs" while the code always ran range(5000)).
        with h5py.File(dataf, 'r') as f:
            faces = f.get('faces')
            run_batches(gen, disc, gan, faces, logger, range(iters))
    finally:
        # Close loss.csv even when training is interrupted or fails.
        logger.on_train_end()
Example #2
0
def train_gan(dataf, iters=1000000):
    """Train the GAN on the ``faces`` dataset stored in an HDF5 file.

    :param dataf: path to an HDF5 file containing a ``faces`` dataset.
    :param iters: number of iterations handed to ``run_batches`` as
        ``range(iters)``; defaults to the previously hard-coded 1000000.
    """
    gen, disc, gan = build_networks()

    # To continue training from some snapshot (or start from pretrained
    # generator weights), load them here before training, e.g.:
    #   load_weights(gen, Args.genw)
    #   load_weights(disc, Args.discw)

    logger = CSVLogger('loss.csv')  # Keras callbacks can be driven manually
    logger.on_train_begin()  # initialize csv file
    try:
        with h5py.File(dataf, 'r') as f:
            faces = f.get('faces')
            run_batches(gen, disc, gan, faces, logger, range(iters))
    finally:
        # Close loss.csv even if training is interrupted (e.g. Ctrl-C).
        logger.on_train_end()
Example #3
0
    def train(self):
        """Run the epoch loop: train, validate, log, plot, checkpoint, early-stop.

        Each epoch trains ``self.batches`` batches, then runs validation. The
        mean of every loss from ``get_loss_names()`` is logged to the console,
        to ``<folder>/training.csv`` and to ``SaveLoss``. Example images are
        plotted and the models are saved. The loop stops early when
        ``self.stop_criterion`` returns true.
        """
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        # Early stopping monitors 'val_loss_mod2_fused' on the Segmentor.
        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        epoch_template = 'Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])

        progress_bar = Progbar(target=self.batches * self.conf.batch_size)
        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            # train_batch / validate append per-batch values into these lists.
            epoch_loss = {n: [] for n in loss_names}

            for self.batch in range(self.batches):
                self.train_batch(epoch_loss)
                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            self.validate(epoch_loss)

            # Average every loss once for this epoch.
            logs = {n: np.mean(epoch_loss[n]) for n in loss_names}
            log.info(epoch_template %
                     ((self.epoch, self.conf.epochs) + tuple(logs[n] for n in loss_names)))

            # CSVLogger is driven manually, so give it a model whose
            # stop_training flag it can read.
            cl.model = self.model.D_Mask
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            # Plot some example images
            self.img_callback.on_epoch_end(self.epoch)

            self.model.save_models()

            if self.stop_criterion(es, logs):
                log.info('Finished training from early stopping criterion')
                break

        # Close training.csv on both the early-stop and the normal exit.
        cl.on_train_end()
Example #4
0
def train_gan(dataf, iters=1000000, disc_start=20, cont=False):
    """Train the GAN on the ``faces`` dataset stored in an HDF5 file.

    :param dataf: path to an HDF5 file containing a ``faces`` dataset.
    :param iters: number of iterations handed to ``run_batches`` as ``range(iters)``.
    :param disc_start: forwarded unchanged to ``run_batches``.
    :param cont: when truthy, load generator/discriminator weights from
        ``snapshots/<Args.batch_len - 1>.{gen,disc}.hdf5`` before training.
    """
    gen, disc, gan = build_networks()

    if cont:
        # Continue from the snapshot indexed Args.batch_len - 1 (presumably the
        # last one written by a previous run — verify). To load explicit files
        # instead, use load_weights(gen, Args.genw) / load_weights(disc, Args.discw).
        snapshot = Args.batch_len - 1
        load_weights(gen, "snapshots/{}.gen.hdf5".format(snapshot))
        load_weights(disc, "snapshots/{}.disc.hdf5".format(snapshot))

    logger = CSVLogger('loss.csv')  # Keras callbacks can be driven manually
    logger.on_train_begin()  # initialize csv file
    try:
        with h5py.File(dataf, 'r') as f:
            faces = f.get('faces')
            run_batches(gen, disc, gan, faces, logger, range(iters), disc_start)
    finally:
        # Close loss.csv even if training is interrupted.
        logger.on_train_end()
Example #5
0
    def train(self):
        """Run the epoch loop with stochastic weight averaging (SWA).

        Each epoch does the following, in order:

        - trains ``self.batches`` batches;
        - updates the SWA models;
        - validates;
        - logs the mean of every loss (console, ``training.csv``, ``SaveLoss``);
        - plots example images and saves the models.

        When training ends, by early stopping or by running all epochs, the
        callbacks are closed and the SWA sub-models replace the live ones
        before a final save. Previously this only happened on early stop.
        """
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        # Early stopping monitors 'val_loss_mod2_fused' on the Segmentor.
        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        epoch_template = 'Epoch %d/%d: ' + ', '.join([l + ' Loss = %.5f' for l in loss_names])
        logs = None  # stays None only if no epoch runs

        progress_bar = Progbar(target=self.batches * self.conf.batch_size)
        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            # train_batch / validate append per-batch values into these lists.
            epoch_loss = {n: [] for n in loss_names}

            for self.batch in range(self.batches):
                self.train_batch(epoch_loss)
                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            self.set_swa_model_weights()
            for swa_m in self.get_swa_models():
                swa_m.on_epoch_end(self.epoch)

            self.validate(epoch_loss)

            logs = {n: np.mean(epoch_loss[n]) for n in loss_names}
            log.info(epoch_template %
                     ((self.epoch, self.conf.epochs) + tuple(logs[n] for n in loss_names)))

            cl.model = self.model.D_Mask
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            # print images
            self.img_callback.on_epoch_end(self.epoch)

            self.save_models()

            if self.stop_criterion(es, logs):
                log.info('Finished training from early stopping criterion')
                break

        if logs is None:
            # No epoch ran: nothing was averaged, so only close the CSV file.
            cl.on_train_end()
            return

        self._finalize_swa_training(es, cl, logs)

    def _finalize_swa_training(self, es, cl, logs):
        """Close the callbacks, install the SWA sub-models and save them."""
        es.on_train_end(logs)
        cl.on_train_end(logs)
        for swa_m in self.get_swa_models():
            swa_m.on_train_end()

        # Set final model parameters based on SWA
        self.model.D_Mask = self.swa_D_Mask.model
        self.model.D_Image1 = self.swa_D_Image1.model
        self.model.D_Image2 = self.swa_D_Image2.model
        self.model.Encoders_Anatomy[0] = self.swa_Enc_Anatomy1.model
        self.model.Encoders_Anatomy[1] = self.swa_Enc_Anatomy2.model
        self.model.Enc_Modality = self.swa_Enc_Modality.model
        self.model.Anatomy_Fuser = self.swa_Anatomy_Fuser.model
        self.model.Segmentor = self.swa_Segmentor.model
        self.model.Decoder = self.swa_Decoder.model
        self.model.Balancer = self.swa_Balancer.model

        self.save_models()
    def train(self):
        """Adversarially train the p/h image and mask translation models.

        Each epoch draws ``ceil(data_len / batch_size)`` batches from three
        augmented generators: p images+masks, h images+masks, and unpaired
        p masks. For every batch, in order:

        - the self-reconstruction model is trained (if ``conf.self_rec``);
        - the critic is trained ``ncritic[0]`` times (first 25 epochs) or
          ``ncritic[1]`` times (afterwards);
        - the generator is trained once.

        Epoch-mean losses go to the console, ``training.csv`` and ``SaveLoss``.
        Up to 30 example samples per domain are passed to ``ImageCallback``.
        """
        self.init_train_data()
        # make generated (augmented) data
        gen_dict = self.get_datagen_params()
        p_gen = ImageDataGenerator(**gen_dict).flow(x=self.p_images, y=self.p_masks, batch_size=self.conf.batch_size)
        h_gen = ImageDataGenerator(**gen_dict).flow(x=self.h_images, y=self.h_masks, batch_size=self.conf.batch_size)
        # Unpaired p masks, fed to the critic as "real" mask samples.
        random_p_masks = ImageDataGenerator(**gen_dict).flow(x=self.p_masks, batch_size=self.conf.batch_size)

        # initialize training; float() keeps the ceil meaningful under Py2 int division
        batches = int(np.ceil(float(self.conf.data_len) / self.conf.batch_size))
        progress_bar = Progbar(target=batches * self.conf.batch_size)

        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        img_clb = ImageCallback(self.conf, self.model, self.comet_exp)

        loss_names = self.get_loss_names()
        epoch_template = 'Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])
        max_pool = 30  # example samples per domain kept for ImageCallback

        # start training
        for epoch in range(self.conf.epochs):
            log.info("Train epoch %d/%d" % (epoch, self.conf.epochs))
            epoch_loss = {n: [] for n in loss_names}
            pool_p_img, pool_p_msk, pool_h_img, pool_h_msk = [], [], [], []
            # Critic updates per generator update: ncritic[0] for the first
            # 25 epochs, ncritic[1] afterwards.
            n_critic = self.conf.ncritic[0] if epoch < 25 else self.conf.ncritic[1]

            for batch in range(batches):
                p_img, p_msk = next(p_gen)
                h_img, h_msk = next(h_gen)
                r_p_msk = next(random_p_masks)
                n_samples = h_img.shape[0]

                # Keep the first sample of each of the first 30 batches for plotting.
                if len(pool_p_img) < max_pool:
                    pool_p_img.append(p_img[0])
                    pool_p_msk.append(p_msk[0])
                if len(pool_h_img) < max_pool:
                    pool_h_img.append(h_img[0])
                    pool_h_msk.append(h_msk[0])

                # Adversarial ground truths: -1 real, +1 fake; zeros are the
                # targets for the interpolated (gradient-penalty) outputs.
                real_pred = -np.ones((n_samples, 1))
                fake_pred = np.ones((n_samples, 1))
                dummy = np.zeros((n_samples, 1))

                if self.conf.self_rec:
                    h_test_sr = self.model.train_self_rec.fit([h_img, h_msk], [h_img, h_img], epochs=1, verbose=0)
                    epoch_loss["test_self_rec_loss"].append(np.mean(h_test_sr.history["loss"]))
                else:
                    epoch_loss["test_self_rec_loss"].append(0)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Synthetic images/masks are generated once and reused for every critic step.
                cy1_pse_h_img = self.model.G_d_to_h.predict(p_img)
                cy1_seg_d_msk = self.model.S_d_to_msk.predict(p_img)
                cy2_fake_h_img = self.model.G_h_to_d.predict([h_img, h_msk])

                for _ in range(n_critic):
                    cy1_average = self._random_interpolation(h_img, cy1_pse_h_img, n_samples)
                    # Each interpolation uses its own epsilon for both terms
                    # (the mask term previously mixed in cy1_epsilon).
                    cy1_average_msk = self._random_interpolation(r_p_msk, cy1_seg_d_msk, n_samples)
                    cy2_average = self._random_interpolation(h_img, cy2_fake_h_img, n_samples)

                    h_d = self.model.critic_model.fit(
                        [h_img, cy1_pse_h_img, cy1_average,
                         r_p_msk, cy1_seg_d_msk, cy1_average_msk,
                         h_img, cy2_fake_h_img, cy2_average],
                        [real_pred, fake_pred, dummy, real_pred, fake_pred, dummy,
                         real_pred, fake_pred, dummy],
                        epochs=1, verbose=0)

                # Losses come from the last critic step only.
                d_dis_pse_image_loss = np.mean([h_d.history['dis_cy1_I_pse_h_loss'], h_d.history['dis_cy2_I_pse_h_loss']])
                d_dis_r_image_loss = np.mean([h_d.history['dis_cy1_I_h_loss'], h_d.history['dis_cy2_I_h_loss']])
                d_dis_d_mask_loss = np.mean([h_d.history['dis_cy1_M_d_loss'], h_d.history['dis_cy1_M_seg_d_loss']])
                d_gp_loss = np.mean([h_d.history['gp_cy1_I_h_loss'], h_d.history['gp_cy2_I_h_loss'], h_d.history['gp_cy1_M_d_loss']])
                epoch_loss['d_dis_pse_image_loss'].append(d_dis_pse_image_loss)
                epoch_loss['d_dis_r_image_loss'].append(d_dis_r_image_loss)
                epoch_loss['d_dis_d_mask_loss'].append(d_dis_d_mask_loss)
                epoch_loss['d_gp_loss'].append(d_gp_loss)

                # --------------------
                #  Train Generator
                # --------------------
                h_g = self.model.gan.fit([p_img, h_img, h_msk],
                                         [real_pred, real_pred, p_img, real_pred, h_img, h_msk],
                                         epochs=1, verbose=0)
                g_dis_pse_image_loss = np.mean([h_g.history['cy1_dis_I_pse_h_loss'], h_g.history['cy2_dis_I_pse_d_loss']])
                g_rec_image_loss = np.mean([h_g.history['cy2_I_rec_h_loss'], h_g.history['cy1_I_rec_d_loss']])
                g_dis_d_mask_loss = np.mean(h_g.history['cy1_dis_M_seg_d_loss'])
                epoch_loss['g_dis_pse_image_loss'].append(g_dis_pse_image_loss)
                epoch_loss['g_rec_image_loss'].append(g_rec_image_loss)
                epoch_loss['g_dis_d_mask_loss'].append(g_dis_d_mask_loss)

                # Plot the progress
                progress_bar.update((batch + 1) * self.conf.batch_size)

            # Average every loss exactly once for this epoch.
            logs = {n: np.mean(epoch_loss[n]) for n in loss_names}
            log.info(epoch_template %
                     ((epoch, self.conf.epochs) + tuple(logs[n] for n in loss_names)))

            cl.model = self.model.D_pse_h
            cl.model.stop_training = False
            cl.on_epoch_end(epoch, logs)
            sl.on_epoch_end(epoch, logs)

            pool_p_img = np.asarray(pool_p_img)
            pool_p_msk = np.asarray(pool_p_msk)
            pool_h_img = np.asarray(pool_h_img)
            pool_h_msk = np.asarray(pool_h_msk)
            print("pool_to_print_p_img: ", np.shape(pool_p_img))
            img_clb.on_epoch_end(epoch, pool_p_img, pool_p_msk, pool_h_img, pool_h_msk)

        # Close training.csv once all epochs are done.
        cl.on_train_end()

    @staticmethod
    def _random_interpolation(real, fake, n_samples):
        """Return ``eps * real + (1 - eps) * fake`` with one uniform eps per sample.

        eps has shape ``(n_samples, 1, 1, 1)`` so it broadcasts over the
        remaining axes (assumes 4-D batched inputs — TODO confirm).
        """
        eps = np.random.uniform(0, 1, size=(n_samples, 1, 1, 1))
        return eps * real + (1 - eps) * fake
class ExtendedLogger(Callback):
    """Callback wrapping CSVLogger + TensorBoard and adding pose-validation artefacts.

    Creates ``csv``, ``tensorboard``, ``predictions`` and ``plots`` sub-directories
    under ``output_dir``. At the end of every epoch it:

    - predicts the output of ``prediction_layer`` on the validation data;
    - saves the raw predictions, an error-CDF plot and top-down trajectory plots;
    - evaluates the registered validation metrics;
    - forwards the enriched ``logs`` to both wrapped callbacks.
    """

    # Class-level registry kept for backward compatibility. Each instance copies
    # it in __init__, so metrics added through one logger no longer leak into
    # every other ExtendedLogger (the dict used to be shared by all instances).
    val_data_metrics = {}

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):
        """
        :param prediction_layer: layer identifier passed to ``model.get_layer``;
            its output is used as the prediction.
        :param output_dir: root directory for csv/tensorboard/predictions/plots.
        :param stateful: whether the model is stateful (requires
            ``stateful_reset_interval``).
        :param stateful_reset_interval: interval for ``ResetStatesCallback``
            used during prediction.
        :param starting_indicies: optional ``{'valid': [i0, i1, ...]}`` segment
            boundaries for trajectory plots; defaults to ``[0, 1000, ..., 4000]``.
        :raises ValueError: if ``stateful`` is set without ``stateful_reset_interval``.
        """
        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If model is stateful, then seq-len has to be defined!')

        super(ExtendedLogger, self).__init__()

        # Per-instance copy of the metric registry (see class comment).
        self.val_data_metrics = dict(type(self).val_data_metrics)

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies
        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    # --- plain delegation to the wrapped callbacks -------------------------

    def set_params(self, params):
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Compute validation artefacts/metrics and forward them to the wrapped callbacks.

        ``logs`` is updated in place so the caller's dict carries the new metrics.
        """
        if logs is None:
            logs = {}

        with timeit('metrics'):

            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)

            batch_size = self.params['batch_size']

            # Strip trailing entries from validation_data: presumably
            # [inputs, targets, sample_weights(, float learning phase)] — TODO confirm.
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]

            # Assumes a single-input, single-output model (val_data == [x, y]).
            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            # NOTE(review): stock Keras `predict` has no `callback` kwarg; this
            # presumably relies on a patched Keras — verify.
            y_pred = self.prediction_model.predict(val_data[:-1],
                                                   batch_size=batch_size,
                                                   verbose=1,
                                                   callback=callback)

            print(y_true.shape, y_pred.shape)

            self.write_prediction(epoch, y_true, y_pred)

            # Flatten to 7-value pose rows (cols 0-1 used as x/y, 3-6 as a
            # quaternion by the plotting helpers).
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            new_logs = {
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            }
            logs.update(new_logs)

            homo_logs = self.try_add_homoscedastic_params()
            logs.update(homo_logs)

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        """Register several ``name -> metric(y_true, y_pred)`` functions for this logger."""
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        """Register one ``metric(y_true, y_pred)`` function under ``name`` for this logger."""
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        """Return the first weight of each homoscedastic loss layer that exists.

        Keys: ``'pos_log_var'`` (layer ``homo_pos_loss``) and ``'quat_log_var'``
        (layer ``homo_quat_loss``); a missing layer simply omits its key instead
        of raising.
        """
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')

        homo_logs = {}
        if homo_pos_loss_layer:
            homo_logs['pos_log_var'] = np.array(homo_pos_loss_layer.get_weights()[0])
        if homo_quat_loss_layer:
            homo_logs['quat_log_var'] = np.array(homo_quat_loss_layer.get_weights()[0])
        return homo_logs

    def write_prediction(self, epoch, y_true, y_pred):
        """Save ``{'y_pred', 'y_true'}`` to ``<pred_dir>/<epoch:04d>_predictions.npy``."""
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self,
                                  epoch,
                                  y_true,
                                  y_pred,
                                  max_segment=1000):
        """Plot every ``[begin, end)`` validation segment.

        Segments longer than ``max_segment`` are additionally plotted in
        ``max_segment``-sized pieces.
        """
        if self.starting_indicies is None:
            # list() is required on Python 3, where range + list is a TypeError.
            self.starting_indicies = {'valid': list(range(0, 4000, 1000)) + [4000]}

        for begin, end in pairwise(self.starting_indicies['valid']):

            diff = end - begin
            if diff > max_segment:
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)

            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        """Save a top-down x/y plot of true (green) vs predicted (red) rows ``begin:end``.

        Thin black lines connect each true point to its prediction.
        """
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        plt.clf()

        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k',
                     linestyle='-',
                     linewidth=0.3,
                     alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        # Frame the plot on the ground-truth path with a 0.2 margin.
        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)

        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
          .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        """Save empirical CDFs of absolute position and angle errors to ``<epoch:04d>_cdf.pdf``."""
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()

        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
Example #8
0
    def fit(self):
        """
        Train SDNet.

        Each epoch runs ``conf.batches`` iterations. Each iteration does one
        generator step and then one discriminator step on a labelled and an
        unlabelled batch. The following sanity checks run along the way:

        - a generator step must not change the discriminator's weights;
        - both models' weights must change over the epoch.

        After each epoch the method:

        - plots example images;
        - validates;
        - logs mean losses (console, ``SaveLoss``, ``training.csv``);
        - saves the models;
        - applies the early-stopping criterion.
        """
        log.info('Training SDNet')

        # Load data
        self.init_train()

        # Initialise callbacks
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        si = SDNetCallback(self.conf.folder, self.conf.batch_size, self.sdnet)
        es = EarlyStopping('val_loss', min_delta=0.001, patience=20)
        es.on_train_begin()

        loss_names = [
            'adv_M', 'adv_X', 'rec_X', 'rec_M', 'rec_Z', 'dis_M', 'dis_X',
            'mask', 'image', 'val_loss'
        ]
        epoch_template = 'Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])

        progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)

        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            # These are used only for printing images.
            real_lb_pool, real_ul_pool = [], []

            epoch_loss = {n: [] for n in loss_names}

            D_initial_weights = self._mean_weight(self.sdnet.D_model)
            G_initial_weights = self._mean_weight(self.sdnet.G_model)
            for self.batch in range(self.conf.batches):
                real_lb = next(self.gen_X_L)
                real_ul = next(self.gen_X_U)

                # Add image/mask batch to the data pool, one single-sample slice pair per item.
                x, m = real_lb
                real_lb_pool.extend([(x[i:i + 1], m[i:i + 1])
                                     for i in range(x.shape[0])])
                real_ul_pool.extend(real_ul)

                # A generator step must leave the discriminator's weights untouched.
                D_weights_before = self._mean_weight(self.sdnet.D_model)
                self.train_batch_generator(real_lb, real_ul, epoch_loss)
                assert D_weights_before == self._mean_weight(self.sdnet.D_model)

                self.train_batch_discriminator(real_lb, real_ul, epoch_loss)

                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            # Check training is altering weights
            assert D_initial_weights != self._mean_weight(self.sdnet.D_model)
            assert G_initial_weights != self._mean_weight(self.sdnet.G_model)

            # Plot some example images
            si.on_epoch_end(self.epoch, np.array(real_lb_pool),
                            np.array(real_ul_pool))

            self.validate(epoch_loss)

            # Calculate epoch losses
            logs = {n: np.mean(epoch_loss[n]) for n in loss_names}
            log.info(epoch_template %
                     ((self.epoch, self.conf.epochs) + tuple(logs[n] for n in loss_names)))
            sl.on_epoch_end(self.epoch, logs)

            # log losses to csv
            cl.model = self.sdnet.D_model
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)

            # save models
            self.sdnet.save_models()

            # early stopping
            if self.stop_criterion(es, self.epoch, logs):
                log.info('Finished training from early stopping criterion')
                break

        # Close training.csv on both the early-stop and the normal exit.
        cl.on_train_end()

    @staticmethod
    def _mean_weight(model):
        """Return the mean of per-tensor means of ``model.get_weights()`` (a cheap change fingerprint)."""
        return np.mean([np.mean(w) for w in model.get_weights()])
Example #9
0
    def train(self):
        self.init_train_data()
        # make genetrated data
        gen_dict = self.get_datagen_params()

        # Here we need to concatenate age and AD labels, in order to use Function ImageDataGenerator
        yng_labels = np.concatenate([self.train_age_yng, self.train_AD_yng],
                                    axis=1)
        old_labels = np.concatenate([self.train_age_old, self.train_AD_old],
                                    axis=1)

        old_gen = ImageDataGenerator(**gen_dict).flow(
            x=self.train_img_old,
            y=old_labels,
            batch_size=self.conf.batch_size)
        yng_gen = ImageDataGenerator(**gen_dict).flow(
            x=self.train_img_yng,
            y=yng_labels,
            batch_size=self.conf.batch_size)

        # initialize training
        batches = int(np.ceil(self.conf.data_len / self.conf.batch_size))
        progress_bar = Progbar(target=batches * self.conf.batch_size)

        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        img_clb = ImageCallback(self.conf, self.model, self.comet_exp)

        # clr = CyclicLR(base_lr=self.conf.lr/5, max_lr=self.conf.lr,
        #                step_size=batches*4, mode='triangular')

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        # start training
        for epoch in range(self.conf.epochs):
            log.info("Train epoch %d/%d" % (epoch, self.conf.epochs))
            epoch_loss = {n: [] for n in loss_names}
            epoch_loss_list = []
            pool_to_print_old, pool_to_print_yng = [], []

            for batch in range(batches):
                old_img, old_labels = next(old_gen)
                yng_img, yng_labels = next(yng_gen)

                # Return labels to age and AD vectors
                old_age = old_labels[:, :self.conf.age_dim, :]
                old_AD = old_labels[:, self.conf.age_dim:, :]

                yng_age = yng_labels[:, :self.conf.age_dim, :]
                yng_AD = yng_labels[:, self.conf.age_dim:, :]

                if len(pool_to_print_old) < 30:
                    pool_to_print_old.append(old_img)

                if len(pool_to_print_yng) < 30:
                    pool_to_print_yng.append(yng_img)

                # Adversarial ground truths
                real_pred = -np.ones((old_img.shape[0], 1))
                fake_pred = np.ones((old_img.shape[0], 1))
                dummy = np.zeros((old_img.shape[0], 1))
                dummy_Img = np.ones(old_img.shape)
                # ---------------------
                #  Train Discriminator
                # ---------------------
                age_gap = calculate_age_diff(yng_age, old_age)
                diff_age = get_age_ord_vector(age_gap,
                                              expand_dim=1,
                                              con=self.conf.age_con,
                                              ord=self.conf.age_ord,
                                              age_dim=self.conf.age_dim)
                # Get a group of synthetic msks and imgs
                gen_masks = self.model.generator.predict(
                    [yng_img, diff_age, old_AD])
                gen_old_img = np.tanh(
                    gen_masks +
                    yng_img) if self.conf.use_tanh else gen_masks + yng_img
                # Need to train discriminators more iterations:
                if epoch < 25:
                    for _ in range(self.conf.ncritic[0]):
                        epsilon = np.random.uniform(0,
                                                    1,
                                                    size=(old_img.shape[0], 1,
                                                          1, 1))
                        interpolation = epsilon * old_img + (
                            1 - epsilon) * gen_old_img
                        h_d = self.model.critic_model.fit([
                            old_img, old_age, old_AD, gen_old_img, old_age,
                            old_AD, interpolation, old_age, old_AD
                        ], [real_pred, fake_pred, dummy],
                                                          epochs=1,
                                                          verbose=0)
                        # , callbacks=[clr])
                    # d_loss_bce = np.mean([h_real.history['binary_crossentropy'], h_fake.history['binary_crossentropy']])
                else:
                    for _ in range(self.conf.ncritic[1]):
                        epsilon = np.random.uniform(0,
                                                    1,
                                                    size=(old_img.shape[0], 1,
                                                          1, 1))
                        interpolation = epsilon * old_img + (
                            1 - epsilon) * gen_old_img
                        h_d = self.model.critic_model.fit([
                            old_img, old_age, old_AD, gen_old_img, old_age,
                            old_AD, interpolation, old_age, old_AD
                        ], [real_pred, fake_pred, dummy],
                                                          epochs=1,
                                                          verbose=0)
                        # , callbacks=[clr])

                # d_loss_bce = np.mean(h_real.history['d_loss'])
                print('d_real_loss', np.mean(h_d.history['d_real_loss']),
                      'd_fake_loss', np.mean(h_d.history['d_fake_loss']))
                d_loss_bce = np.mean(
                    [h_d.history['d_real_loss'], h_d.history['d_fake_loss']])
                d_loss_real = np.mean(h_d.history['d_real_loss'])
                d_loss_fake = np.mean(h_d.history['d_fake_loss'])
                d_loss_gp = np.mean(h_d.history['gp_loss'])
                epoch_loss['Discriminator_loss'].append(d_loss_bce)
                epoch_loss['Discriminator_real_loss'].append(d_loss_real)
                epoch_loss['Discriminator_fake_loss'].append(d_loss_fake)
                epoch_loss['Discriminator_gp_loss'].append(d_loss_gp)
                # --------------------
                #  Train Generator
                # --------------------
                # Train the generator, want discriminator to mistake images as real
                h = self.model.gan.fit(
                    [yng_img, old_age, diff_age, age_gap, old_AD],
                    [real_pred, dummy_Img],
                    epochs=1,
                    verbose=0)
                # , callbacks=[clr])
                # print(h.history)
                g_loss_bce = h.history['discriminator_loss']
                g_loss_l1 = h.history['map_l1_reg_loss']

                # Deal with epoch loss

                epoch_loss['Generator_fake_loss'].append(g_loss_bce)
                epoch_loss['Generator_l1_reg_loss'].append(g_loss_l1)

                #-----------------------------------------
                # Train Generator by self-regularization
                #-----------------------------------------
                diff_age_zero = yng_age - yng_age
                h = self.model.GAN_zero_reg([yng_img, diff_age_zero, yng_AD],
                                            yng_img,
                                            epochs=1,
                                            verbose=0)
                g_zero_reg = np.mean(h.history['self_reg'])
                epoch_loss['Generator_zero_gre_loss'].append(g_zero_reg)

                # Plot the progress
                progress_bar.update((batch + 1) * self.conf.batch_size)

            for n in loss_names:
                epoch_loss_list.append((n, np.mean(epoch_loss[n])))
                total_loss[n].append(np.mean(epoch_loss[n]))

            log.info(
                str('Epoch %d/%d: ' +
                    ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                ((epoch, self.conf.epochs) + tuple(total_loss[l][-1]
                                                   for l in loss_names)))
            logs = {l: total_loss[l][-1] for l in loss_names}

            cl.model = self.model.discriminator
            cl.model.stop_training = False
            cl.on_epoch_end(epoch, logs)
            sl.on_epoch_end(epoch, logs)
            img_clb.on_epoch_end(epoch, yng_img, yng_age, old_img, old_age)
Example #10
0
    def train(self):
        """Run the full training loop with periodic test/validation evaluation.

        For each of ``self.conf.epochs`` epochs this trains
        ``self.conf.batches`` batches via ``self.train_batch``, validates,
        saves the models, and -- every ``self.eval_train_interval`` epochs and
        on the last epoch -- plots example images, evaluates on the test and
        validation splits and saves the best-on-test model as ``'BestModel'``.
        Per-epoch mean losses are logged and passed to the CSV logger and
        ``SaveLoss``. Training may stop early, but only after the halfway
        point, when ``self.stop_criterion`` says so.

        Side effects: deletes any existing ``test-performance.csv`` and
        ``validation-performance.csv`` in ``self.conf.folder``, sets
        ``self.eval_train_interval``, and uses ``self.epoch`` / ``self.batch``
        as the loop variables (so they persist on the instance).
        """
        def _learning_rate_schedule(epoch):
            # Exponential decay lr * exp(-coef * (epoch + 1)); note epoch 0 is
            # already decayed once relative to self.conf.lr.
            return self.conf.lr * math.exp(self.lr_schedule_coef * (-epoch - 1))

        # Start from clean performance logs (write_csv below presumably
        # appends rows -- confirm).
        if os.path.exists(os.path.join(self.conf.folder, 'test-performance.csv')):
            os.remove(os.path.join(self.conf.folder, 'test-performance.csv'))
        if os.path.exists(os.path.join(self.conf.folder, 'validation-performance.csv')):
            os.remove(os.path.join(self.conf.folder, 'validation-performance.csv'))

        log.info('Training Model')
        # Best value returned by self.validate so far; not otherwise read in
        # this method.
        dice_record = 0
        # Evaluate on test/validation roughly 50 times over the run, but never
        # more often than once per epoch.
        self.eval_train_interval = int(max(1, self.conf.epochs/50))

        self.init_train_data()
        lr_callback = LearningRateScheduler(_learning_rate_schedule)

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        # Monitors 'Validate_Dice'; min_delta and patience are passed
        # positionally (Keras order: monitor, min_delta, patience).
        es = EarlyStopping('Validate_Dice', self.conf.min_delta, self.conf.patience)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        # Sorted so the per-epoch log lines list losses in a stable order.
        loss_names.sort()
        total_loss = {n: [] for n in loss_names}

        # A single progress bar counted in batches, reused for every epoch.
        progress_bar = Progbar(target=self.conf.batches)
        # self.img_clb.on_epoch_end(self.epoch)

        best_performance = 0.
        # Most recent test score; carried over to epochs where evaluation is
        # skipped (see the Test_Performance_Dice append below).
        test_performance = 0.
        # Incremented once per epoch and once per batch; not read in this method.
        total_iters = 0
        for self.epoch in range(self.conf.epochs):
            total_iters+=1
            log.info('Epoch %d/%d' % (self.epoch+1, self.conf.epochs))

            epoch_loss = {n: [] for n in loss_names}
            epoch_loss_list = []

            # After this loop self.batch holds the last batch index; the
            # write_csv calls further down pass it along.
            for self.batch in range(self.conf.batches):
                total_iters += 1
                self.train_batch(epoch_loss, lr_callback)
                progress_bar.update(self.batch + 1)

            val_dice = self.validate(epoch_loss)
            if val_dice > dice_record:
                dice_record = val_dice

            # Give the CSV logger a model and clear stop_training, presumably
            # because CSVLogger.on_epoch_end inspects model.stop_training.
            cl.model = self.model.D_Reconstruction
            cl.model.stop_training = False

            self.model.save_models()

            # Plot some example images
            if self.epoch % self.eval_train_interval == 0 or self.epoch == self.conf.epochs - 1:
                self.img_clb.on_epoch_end(self.epoch)
                folder = os.path.join(os.path.join(self.conf.folder, 'test_during_train'),
                                      'test_results_%s_epoch%d'
                                      % (self.conf.test_dataset, self.epoch))
                if not os.path.exists(folder):
                    os.makedirs(folder)
                test_performance = self.test_modality(folder, self.conf.modality, 'test', False)
                # Keep a separate snapshot of the best model on the test split.
                if test_performance > best_performance:
                    best_performance = test_performance
                    self.model.save_models('BestModel')
                    log.info("BestModel@Epoch%d" % self.epoch)

                folder = os.path.join(os.path.join(self.conf.folder, 'test_during_train'),
                                      'validation_results_%s_epoch%d'
                                      % (self.conf.test_dataset, self.epoch))
                if not os.path.exists(folder):
                    os.makedirs(folder)
                validation_performance = self.test_modality(folder, self.conf.modality, 'validation', False)
                # NOTE(review): check_batch_iters is not defined in this method;
                # presumably a module-level constant -- confirm it exists.
                if self.conf.batches>check_batch_iters:
                    self.write_csv(os.path.join(self.conf.folder, 'test-performance.csv'),
                                   self.epoch, self.batch, test_performance)
                    self.write_csv(os.path.join(self.conf.folder, 'validation-performance.csv'),
                                   self.epoch, self.batch, validation_performance)
            # Assumes 'Test_Performance_Dice' is one of loss_names; otherwise
            # this raises KeyError because epoch_loss only has those keys.
            epoch_loss['Test_Performance_Dice'].append(test_performance)

            for n in loss_names:
                epoch_loss_list.append((n, np.mean(epoch_loss[n])))
                total_loss[n].append(np.mean(epoch_loss[n]))

            # First 5 epochs: latest value of each loss. Afterwards: the last
            # five per-epoch values, oldest first. These headers print the
            # 0-based epoch index, unlike the 'Epoch %d/%d' lines using epoch+1.
            if self.epoch<5:
                log.info(str('Epoch %d/%d:\n' + ''.join([l + ' Loss = %.3f\n' for l in loss_names])) %
                         ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
            else:
                info_str = str('Epoch %d/%d:\n' % (self.epoch, self.conf.epochs))
                loss_info = ''
                for l in loss_names:
                    loss_info = loss_info + l + ' Loss = %.3f->%.3f->%.3f->%.3f->%.3f\n' % \
                                (total_loss[l][-5],
                                 total_loss[l][-4],
                                 total_loss[l][-3],
                                 total_loss[l][-2],
                                 total_loss[l][-1])
                log.info(info_str + loss_info)
            log.info("BestTest:%f" % best_performance)
            log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))
            logs = {l: total_loss[l][-1] for l in loss_names}
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            # Early stopping is honoured only in the second half of training.
            # stop_criterion is evaluated first, so it is called every epoch.
            if self.stop_criterion(es, logs) and self.epoch > self.conf.epochs / 2:
                log.info('Finished training from early stopping criterion')
                self.img_clb.on_epoch_end(self.epoch)
                break
Example #11
0
def train_model(model, data, config, include_tensorboard):
    """Train ``model`` over every dataset in ``data``, driving callbacks by hand.

    Each epoch makes several ``fit_generator`` calls (two per dataset), so
    the epoch-level callbacks (History, ModelCheckpoint, EarlyStopping,
    CSVLogger and optionally TensorBoard) are invoked manually. They receive
    logs aggregated across datasets by ``average_logs``.

    Per epoch and per dataset, the model's states and the dataset's
    generators are reset. The model is then fit for one pass on
    ``train_generators[0]`` (validated on ``valid_generators[0]``) and one
    pass on ``train_generators[1]`` (no validation).

    Args:
        model: Keras-style model (uses ``fit_generator`` and
            ``reset_states``).
        data: object with a ``datasets`` iterable. Each dataset exposes
            ``name``, ``reset_generators()``, ``train_generators`` (two
            entries) and ``valid_generators`` (at least one entry), whose
            items provide ``size()``.
        config: supplies ``model_file()``, ``csv_log_file()``,
            ``min_delta``, ``patience`` and ``max_epochs``.
        include_tensorboard: if true, also log to TensorBoard.
    """
    model_history = History()
    model_history.on_train_begin()
    saver = ModelCheckpoint(full_path(config.model_file()), verbose=1,
                            save_best_only=True, period=1)
    saver.set_model(model)
    early_stopping = EarlyStopping(min_delta=config.min_delta,
                                   patience=config.patience, verbose=1)
    early_stopping.set_model(model)
    early_stopping.on_train_begin()
    csv_logger = CSVLogger(full_path(config.csv_log_file()))
    csv_logger.on_train_begin()
    if include_tensorboard:
        tensorboard = TensorBoard(histogram_freq=10, write_images=True)
        tensorboard.set_model(model)
    else:
        # No-op stand-in so the epoch loop can call it unconditionally.
        tensorboard = Callback()

    # The CSV logger keeps its file open between on_train_begin and
    # on_train_end. Run the loop under try/finally so that file (and the
    # TensorBoard writer) is released even if training raises or is
    # interrupted.
    try:
        epoch = 0
        # NOTE(review): starting at 0 with ``<=`` runs up to max_epochs + 1
        # epochs. Kept to preserve existing behaviour; confirm whether ``<``
        # was intended.
        while epoch <= config.max_epochs:
            epoch_history = History()
            epoch_history.on_train_begin()
            valid_sizes = []
            train_sizes = []
            print("Epoch:", epoch)
            for dataset in data.datasets:
                print("dataset:", dataset.name)
                model.reset_states()
                dataset.reset_generators()

                # Pass 1: first training generator, validated.
                train_gen = dataset.train_generators[0]
                valid_gen = dataset.valid_generators[0]
                valid_size = valid_gen.size()
                train_size = train_gen.size()
                valid_sizes.append(valid_size)
                train_sizes.append(train_size)
                fit_history = model.fit_generator(
                    train_gen,
                    train_size,
                    nb_epoch=1,
                    verbose=0,
                    validation_data=valid_gen,
                    nb_val_samples=valid_size)
                epoch_history.on_epoch_end(epoch, last_logs(fit_history))

                # Pass 2: second training generator, no validation.
                extra_gen = dataset.train_generators[1]
                extra_size = extra_gen.size()
                train_sizes.append(extra_size)
                fit_history = model.fit_generator(
                    extra_gen,
                    extra_size,
                    nb_epoch=1,
                    verbose=0)
                epoch_history.on_epoch_end(epoch, last_logs(fit_history))

            epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
            model_history.on_epoch_end(epoch, logs=epoch_logs)
            saver.on_epoch_end(epoch, logs=epoch_logs)
            early_stopping.on_epoch_end(epoch, epoch_logs)
            csv_logger.on_epoch_end(epoch, epoch_logs)
            tensorboard.on_epoch_end(epoch, epoch_logs)
            epoch += 1

            # EarlyStopping sets stopped_epoch when its patience runs out.
            if early_stopping.stopped_epoch > 0:
                break
    finally:
        early_stopping.on_train_end()
        csv_logger.on_train_end()
        tensorboard.on_train_end({})
    def train(self):
        """Train the model epoch by epoch with per-epoch sanity checks and logging.

        Each epoch runs ``self.conf.batches`` calls to ``self.train_batch``,
        checks that training changed the generator (and, when applicable,
        discriminator) weights, validates, logs mean losses to the CSV logger
        and ``SaveLoss``, plots example images and saves the models. The loop
        ends early when ``self.stop_criterion`` returns true.
        """
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        # Early stopping on 'val_loss' (min_delta 0.01, patience 100), attached
        # to the Segmentor model.
        es = EarlyStopping('val_loss', min_delta=0.01, patience=100)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        # Progress is counted in samples (batches * batch_size). The bar is
        # created once and reused for every epoch.
        progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)

        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            epoch_loss = {n: [] for n in loss_names}
            # Built below but not read in this method.
            epoch_loss_list = []

            # Snapshot the mean of all trainer weights so we can check after the
            # epoch that training actually updated them.
            D_initial_weights = np.mean(
                [np.mean(w) for w in self.model.D_trainer.get_weights()])
            G_initial_weights = np.mean(
                [np.mean(w) for w in self.model.G_trainer.get_weights()])
            for self.batch in range(self.conf.batches):
                # real_pools = self.add_to_pool(data, real_pools)
                self.train_batch(epoch_loss)

                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            G_final_weights = np.mean(
                [np.mean(w) for w in self.model.G_trainer.get_weights()])
            D_final_weights = np.mean(
                [np.mean(w) for w in self.model.D_trainer.get_weights()])

            # The discriminator check is skipped when there is no unlabelled
            # generator or D_trainer is frozen. NOTE(review): these are asserts
            # (stripped under -O), and comparing a mean is a coarse heuristic --
            # an update that leaves the mean unchanged would trip it.
            assert self.gen_unlabelled is None or not self.model.D_trainer.trainable \
                   or D_initial_weights != D_final_weights
            assert G_initial_weights != G_final_weights

            self.validate(epoch_loss)

            for n in loss_names:
                epoch_loss_list.append((n, np.mean(epoch_loss[n])))
                total_loss[n].append(np.mean(epoch_loss[n]))
            # One summary line with the latest mean of every loss; the epoch
            # index printed here is 0-based.
            log.info(
                str('Epoch %d/%d: ' +
                    ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1]
                                                        for l in loss_names)))
            logs = {l: total_loss[l][-1] for l in loss_names}

            # Give the CSV logger a model and clear stop_training before
            # on_epoch_end, presumably because CSVLogger inspects
            # model.stop_training there.
            cl.model = self.model.D_Mask
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            # Plot some example images
            self.img_clb.on_epoch_end(self.epoch)

            self.model.save_models()

            if self.stop_criterion(es, logs):
                log.info('Finished training from early stopping criterion')
                break