Example #1
    def train(self):
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.batches * self.conf.batch_size)
        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            epoch_loss = {n: [] for n in loss_names}
            epoch_loss_list = []

            for self.batch in range(self.batches):
                self.train_batch(epoch_loss)
                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            self.validate(epoch_loss)

            for n in loss_names:
                epoch_loss_list.append((n, np.mean(epoch_loss[n])))
                total_loss[n].append(np.mean(epoch_loss[n]))
            log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                     ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
            logs = {l: total_loss[l][-1] for l in loss_names}

            cl.model = self.model.D_Mask
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            # Plot some example images
            self.img_callback.on_epoch_end(self.epoch)

            self.model.save_models()

            if self.stop_criterion(es, logs):
                log.info('Finished training from early stopping criterion')
                break
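
A minimal, self-contained sketch of the pattern above: Keras callbacks
(CSVLogger, EarlyStopping) driven by hand inside a custom training loop. The
toy model and random data are placeholders; stop_training is initialised
explicitly because nothing here goes through fit().

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
model.stop_training = False  # read by CSVLogger, written by EarlyStopping

cl = CSVLogger('training.csv')
cl.set_model(model)
cl.on_train_begin()

es = EarlyStopping(monitor='val_loss', min_delta=0.01, patience=3)
es.set_model(model)
es.on_train_begin()

x, y = np.random.rand(64, 4), np.random.rand(64, 1)
for epoch in range(100):
    loss = float(model.train_on_batch(x, y))
    logs = {'loss': loss, 'val_loss': float(model.evaluate(x, y, verbose=0))}
    cl.on_epoch_end(epoch, logs)  # appends one row to training.csv
    es.on_epoch_end(epoch, logs)  # flips model.stop_training after patience epochs
    if model.stop_training:
        break
cl.on_train_end()
es.on_train_end()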
Example #2
def main(dirs, z_size=32, batch_size=100, learning_rate=0.0001,
         kl_tolerance=0.5, epochs=100, save_model=False, verbose=True,
         optimizer="Adam"):

    if save_model:
        model_save_path = "tf_vae"
        if not os.path.exists(model_save_path):
            os.makedirs(model_save_path)

    gen = DriveDataGenerator(dirs, image_size=(64, 64), batch_size=batch_size,
                             shuffle=True, max_load=10000, images_only=True)

    num_batches = len(gen)

    reset_graph()

    vae = ConvVAE(z_size=z_size,
                  batch_size=batch_size,
                  learning_rate=learning_rate,
                  kl_tolerance=kl_tolerance,
                  is_training=True,
                  reuse=False,
                  gpu_mode=True,
                  optimizer=optimizer)

    early = EarlyStopping(monitor='loss', min_delta=0.1, patience=5, verbose=verbose, mode='auto')
    early.set_model(vae)
    early.on_train_begin()

    best_loss = sys.maxsize

    if verbose:
        print("epoch\tstep\tloss\trecon_loss\tkl_loss")
    for epoch in range(epochs):
        for idx in range(num_batches):
            batch = gen[idx]

            obs = batch.astype(np.float32) / 255.0  # np.float is removed in NumPy >= 1.24

            feed = {vae.x: obs}

            (train_loss, r_loss, kl_loss, train_step, _) = vae.sess.run([
              vae.loss, vae.r_loss, vae.kl_loss, vae.global_step, vae.train_op
            ], feed)
            
            if train_loss < best_loss:
                best_loss = train_loss

            if save_model:
                if (train_step + 1) % 5000 == 0:
                    vae.save_json("tf_vae/vae.json")
        if verbose:
            print("{} of {}\t{}\t{:.2f}\t{:.2f}\t{:.2f}".format( epoch, epochs, (train_step+1), train_loss, r_loss, kl_loss) )
        gen.on_epoch_end()
        early.on_epoch_end(epoch, logs={"loss": train_loss})
        if getattr(vae, 'stop_training', False):  # set by EarlyStopping when it triggers
            break
    early.on_train_end()


    # finished, final model:
    if save_model:
        vae.save_json("tf_vae/vae.json")

    return best_loss
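
A minimal sketch of the trick this example relies on: Keras EarlyStopping only
touches model.stop_training, so any object carrying that attribute can stand
in for a Keras model. DummyModel below is an illustrative stand-in for
ConvVAE.

from tensorflow.keras.callbacks import EarlyStopping

class DummyModel:
    stop_training = False  # the only attribute EarlyStopping needs

model = DummyModel()
early = EarlyStopping(monitor='loss', min_delta=0.1, patience=5, mode='auto')
early.set_model(model)
early.on_train_begin()

losses = [10.0, 9.0, 8.95, 8.94, 8.93, 8.92, 8.91]  # plateaus after epoch 1
for epoch, loss in enumerate(losses):
    early.on_epoch_end(epoch, logs={'loss': loss})
    if model.stop_training:
        print('early stop at epoch', epoch)  # fires once patience is exhausted
        break
early.on_train_end()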
Example #3
    def train(self):
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('val_loss_mod2_fused', min_delta=0.01, patience=60)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.batches * self.conf.batch_size)
        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            epoch_loss = {n: [] for n in loss_names}
            epoch_loss_list = []

            for self.batch in range(self.batches):
                self.train_batch(epoch_loss)
                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            self.set_swa_model_weights()
            for swa_m in self.get_swa_models():
                swa_m.on_epoch_end(self.epoch)

            self.validate(epoch_loss)

            for n in loss_names:
                epoch_loss_list.append((n, np.mean(epoch_loss[n])))
                total_loss[n].append(np.mean(epoch_loss[n]))
            log.info(
                str('Epoch %d/%d: ' +
                    ', '.join([l + ' Loss = %.5f' for l in loss_names])) %
                ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1]
                                                        for l in loss_names)))
            logs = {l: total_loss[l][-1] for l in loss_names}

            cl.model = self.model.D_Mask
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            # print images
            self.img_callback.on_epoch_end(self.epoch)

            self.save_models()

            if self.stop_criterion(es, logs):
                log.info('Finished training from early stopping criterion')

                es.on_train_end(logs)
                cl.on_train_end(logs)
                for swa_m in self.get_swa_models():
                    swa_m.on_train_end()

                # Set final model parameters based on SWA
                self.model.D_Mask = self.swa_D_Mask.model
                self.model.D_Image1 = self.swa_D_Image1.model
                self.model.D_Image2 = self.swa_D_Image2.model
                self.model.Encoders_Anatomy[0] = self.swa_Enc_Anatomy1.model
                self.model.Encoders_Anatomy[1] = self.swa_Enc_Anatomy2.model
                self.model.Enc_Modality = self.swa_Enc_Modality.model
                self.model.Anatomy_Fuser = self.swa_Anatomy_Fuser.model
                self.model.Segmentor = self.swa_Segmentor.model
                self.model.Decoder = self.swa_Decoder.model
                self.model.Balancer = self.swa_Balancer.model

                self.save_models()
                break
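
The swa_* objects above maintain stochastic weight averaging (SWA) copies of
each network and swap the averaged weights in once early stopping fires. A
minimal illustrative sketch of that running-average idea (hypothetical, not
this repo's implementation):

class SimpleSWA:
    """Keep a running mean of a model's weights; install it when training ends."""

    def __init__(self, model, start_epoch=10):
        self.model = model
        self.start_epoch = start_epoch  # let the weights stabilise first
        self.swa_weights = None
        self.n_averaged = 0

    def on_epoch_end(self, epoch):
        if epoch < self.start_epoch:
            return
        w = self.model.get_weights()
        if self.swa_weights is None:
            self.swa_weights = w
        else:
            # incremental mean: avg += (w - avg) / (n + 1)
            self.swa_weights = [a + (b - a) / (self.n_averaged + 1)
                                for a, b in zip(self.swa_weights, w)]
        self.n_averaged += 1

    def on_train_end(self):
        if self.swa_weights is not None:
            self.model.set_weights(self.swa_weights)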
Example #4
    redlr = ReduceLROnPlateau(monitor='loss_va',
                              factor=conf['red_lr_factor'],
                              patience=conf['red_lr_patience'],
                              verbose=True,
                              mode='min',
                              min_delta=conf['red_lr_eps'],
                              min_lr=conf['red_lr_min_lr'])
    redlr.model = snmt.model
if conf['early_stopping']:
    earlstop = EarlyStopping(monitor='loss_va',
                             min_delta=conf['early_stopping_eps'],
                             patience=conf['early_stopping_patience'],
                             verbose=True,
                             mode='min')
    earlstop.model = snmt.model
    earlstop.on_train_begin()

# Prepare savers.
saver = callbacks.TrainStateSaver(path_trrun,
                                  model=snmt.model,
                                  optimizer=optimizer,
                                  verbose=True)

# Create data loaders.
prn = conf['path_normals'] if conf['normals_stream'] else None
prd = conf['path_dmaps'] if conf['depth_stream'] else None
prm = conf['path_meshes'] if conf['mesh_stream'] else None
tf_dm = TfReshape(tuple(conf['input_shape'][:2]) + (1, ))

ds_tr = DatasetImgNDM(conf['path_imgs'],
                      conf['seqs_tr'],
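
This example drives ReduceLROnPlateau by hand in the same way. A minimal,
self-contained sketch (toy model; 'loss_va' mirrors the monitored key above,
and the constant validation loss is a placeholder):

import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(optimizer='adam', loss='mse')

redlr = ReduceLROnPlateau(monitor='loss_va', factor=0.5, patience=2,
                          verbose=1, mode='min', min_delta=1e-4, min_lr=1e-6)
redlr.set_model(model)  # the callback reads and writes model.optimizer's LR
redlr.on_train_begin()

for epoch in range(8):
    val_loss = 1.0  # placeholder: plug in the real validation loss
    redlr.on_epoch_end(epoch, logs={'loss_va': val_loss})
    # the LR halves every `patience` epochs while the monitored loss is flat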
Example #5
    def train(self,
              train,
              test,
              max_iterations=2000,
              epochs_per_iter=1,
              val_split=0.1,
              val_data=None,
              minibatch_size=32):

        if val_data:
            ln.debug("%s val samples were supplied" % len(val_data[1]))
            X_inputs, Y = train
            X_val_in, Y_val = val_data
        else:
            ln.debug("Splitting train set to get val dat")
            split_idx = int(train[1].shape[0] * (1. - val_split))
            X_inputs_all, Y_all = train

            X_inputs = [x_input[:split_idx] for x_input in X_inputs_all]
            X_val_in = [x_input[split_idx:] for x_input in X_inputs_all]

            Y, Y_val = Y_all[:split_idx], Y_all[split_idx:]

        tX_sents, tY = test

        save_fname = "charmodel%s.h5" % int(time.time())

        checkpoint = ModelCheckpoint(save_fname,
                                     monitor="val_acc",
                                     save_best_only=True,
                                     save_weights_only=True,
                                     verbose=1)

        lr_scheduler = ReduceLROnPlateau(monitor="val_acc",
                                         factor=0.15,
                                         patience=3,
                                         verbose=1,
                                         min_lr=0.00001)
        lr_scheduler.on_train_begin()
        # Stub out on_train_begin so repeated fit() calls cannot reset the
        # callback's state; default the arg since fit() may pass no logs.
        lr_scheduler.on_train_begin = lambda logs=None: None

        early_stopping = EarlyStopping(monitor="val_acc", patience=6)
        early_stopping.on_train_begin()
        early_stopping.on_train_begin = lambda logs=None: None

        if self.model is None:
            self.build_model()

        try:
            for iteration in range(max_iterations):
                ln.info("Starting iteration %s/%s" %
                        (iteration + 1, max_iterations))
                start = time.time()
                hist = self.model.fit(
                    X_inputs,
                    Y,
                    batch_size=minibatch_size,
                    epochs=epochs_per_iter,
                    validation_data=(X_val_in, Y_val),
                    callbacks=[checkpoint, lr_scheduler, early_stopping],
                    verbose=2)

                if early_stopping.model.stop_training:
                    ln.debug("Early stopping triggered, stopping training.")
                    break

                X_sample_inputs = [X_input[:250] for X_input in X_inputs]
                val_sample_inputs = X_val_in

                ln.debug("Evaluating train and val set..")
                train_score = self.evaluate_zero_one(X_sample_inputs,
                                                     Y[:250],
                                                     debug_print=5)
                val_score = self.evaluate_zero_one(val_sample_inputs,
                                                   Y_val,
                                                   debug_print=8)

                ln.info("Current train set accuracy: %s" % train_score, )
                ln.info("Current validation accuracy: %s" % val_score)
                ln.debug("History: %s" % (hist.history, ))
                ln.debug("Epoch took %s seconds" % (time.time() - start))

                # Shuffle data
                all_inputs = shuffle_together(X_inputs + [Y])
                X_inputs, Y = all_inputs[:-1], all_inputs[-1]

        except KeyboardInterrupt:
            ln.info("Interrupted, stopping training!")

        self.model.load_weights(save_fname)

        # ln.info("Trained for %s iterations, saving.." % iteration)
        loss, acc = self.model.evaluate(tX_sents, tY, batch_size=128)
        zero_one = self.evaluate_zero_one(tX_sents, tY, debug_print=10)

        ln.debug("Final test loss: %s. Eval acc: %s" % (loss, acc))
        ln.debug("Final test zero-one accuracy: %s" % zero_one)
        return loss, acc, zero_one
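
The on_train_begin stubbing above is the detail worth copying: fit() invokes
on_train_begin on every call, which would reset EarlyStopping's patience
counter on each outer iteration. Initialising the callback once and replacing
the hook with a no-op preserves its state across fit() calls:

from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor="val_acc", patience=6)
early_stopping.on_train_begin()  # initialise wait/best exactly once
early_stopping.on_train_begin = lambda logs=None: None  # fit() can no longer reset it

# every fit() call now continues the same patience window:
# model.fit(x, y, epochs=1, validation_data=(x_val, y_val),
#           callbacks=[early_stopping])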
Example #6
    def fit(self):
        """
        Train SDNet
        """
        log.info('Training SDNet')

        # Load data
        self.init_train()

        # Initialise callbacks
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()
        si = SDNetCallback(self.conf.folder, self.conf.batch_size, self.sdnet)
        es = EarlyStopping('val_loss', min_delta=0.001, patience=20)
        es.on_train_begin()

        loss_names = [
            'adv_M', 'adv_X', 'rec_X', 'rec_M', 'rec_Z', 'dis_M', 'dis_X',
            'mask', 'image', 'val_loss'
        ]

        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)

        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            real_lb_pool, real_ul_pool = [], []  # used only for plotting example images

            epoch_loss = {n: [] for n in loss_names}

            D_initial_weights = np.mean(
                [np.mean(w) for w in self.sdnet.D_model.get_weights()])
            G_initial_weights = np.mean(
                [np.mean(w) for w in self.sdnet.G_model.get_weights()])
            for self.batch in range(self.conf.batches):
                real_lb = next(self.gen_X_L)
                real_ul = next(self.gen_X_U)

                # Add image/mask batch to the data pool
                x, m = real_lb
                real_lb_pool.extend([(x[i:i + 1], m[i:i + 1])
                                     for i in range(x.shape[0])])
                real_ul_pool.extend(real_ul)

                D_weights1 = np.mean(
                    [np.mean(w) for w in self.sdnet.D_model.get_weights()])
                self.train_batch_generator(real_lb, real_ul, epoch_loss)
                D_weights2 = np.mean(
                    [np.mean(w) for w in self.sdnet.D_model.get_weights()])
                assert D_weights1 == D_weights2

                self.train_batch_discriminator(real_lb, real_ul, epoch_loss)

                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            G_final_weights = np.mean(
                [np.mean(w) for w in self.sdnet.G_model.get_weights()])
            D_final_weights = np.mean(
                [np.mean(w) for w in self.sdnet.D_model.get_weights()])

            # Check training is altering weights
            assert D_initial_weights != D_final_weights
            assert G_initial_weights != G_final_weights

            # Plot some example images
            si.on_epoch_end(self.epoch, np.array(real_lb_pool),
                            np.array(real_ul_pool))

            self.validate(epoch_loss)

            # Calculate epoch losses
            for n in loss_names:
                total_loss[n].append(np.mean(epoch_loss[n]))
            log.info(str('Epoch %d/%d: ' + ', '.join([l + ' Loss = %.3f' for l in loss_names])) % \
                  ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
            logs = {l: total_loss[l][-1] for l in loss_names}
            sl.on_epoch_end(self.epoch, logs)

            # log losses to csv
            cl.model = self.sdnet.D_model
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)

            # save models
            self.sdnet.save_models()

            # early stopping
            if self.stop_criterion(es, self.epoch, logs):
                log.info('Finished training from early stopping criterion')
                break
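
The weight-mean asserts above are a cheap GAN sanity check: a generator step
must leave a frozen discriminator untouched, and a full epoch must move both
networks. A minimal sketch of the idea, usable with any Keras-style model:

import numpy as np

def mean_weight(model):
    # a scalar fingerprint of the weights; equality before and after a step
    # indicates the weights did not change
    return np.mean([np.mean(w) for w in model.get_weights()])

# before = mean_weight(sdnet.D_model)
# train_generator_step(...)  # the discriminator should be frozen here
# assert mean_weight(sdnet.D_model) == before, 'D moved during a G step'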
Example #7
    def train(self):
        def _learning_rate_schedule(epoch):
            return self.conf.lr * math.exp(self.lr_schedule_coef * (-epoch - 1))

        if os.path.exists(os.path.join(self.conf.folder, 'test-performance.csv')):
            os.remove(os.path.join(self.conf.folder, 'test-performance.csv'))
        if os.path.exists(os.path.join(self.conf.folder, 'validation-performance.csv')):
            os.remove(os.path.join(self.conf.folder, 'validation-performance.csv'))

        log.info('Training Model')
        dice_record = 0
        self.eval_train_interval = int(max(1, self.conf.epochs/50))

        self.init_train_data()
        lr_callback = LearningRateScheduler(_learning_rate_schedule)

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('Validate_Dice', self.conf.min_delta, self.conf.patience)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        loss_names.sort()
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches)
        # self.img_clb.on_epoch_end(self.epoch)

        best_performance = 0.
        test_performance = 0.
        total_iters = 0
        for self.epoch in range(self.conf.epochs):
            total_iters += 1
            log.info('Epoch %d/%d' % (self.epoch+1, self.conf.epochs))

            epoch_loss = {n: [] for n in loss_names}
            epoch_loss_list = []

            for self.batch in range(self.conf.batches):
                total_iters += 1
                self.train_batch(epoch_loss, lr_callback)
                progress_bar.update(self.batch + 1)

            val_dice = self.validate(epoch_loss)
            if val_dice > dice_record:
                dice_record = val_dice

            cl.model = self.model.D_Reconstruction
            cl.model.stop_training = False

            self.model.save_models()

            # Plot some example images
            if self.epoch % self.eval_train_interval == 0 or self.epoch == self.conf.epochs - 1:
                self.img_clb.on_epoch_end(self.epoch)
                folder = os.path.join(self.conf.folder, 'test_during_train',
                                      'test_results_%s_epoch%d'
                                      % (self.conf.test_dataset, self.epoch))
                if not os.path.exists(folder):
                    os.makedirs(folder)
                test_performance = self.test_modality(folder, self.conf.modality, 'test', False)
                if test_performance > best_performance:
                    best_performance = test_performance
                    self.model.save_models('BestModel')
                    log.info("BestModel@Epoch%d" % self.epoch)

                folder = os.path.join(self.conf.folder, 'test_during_train',
                                      'validation_results_%s_epoch%d'
                                      % (self.conf.test_dataset, self.epoch))
                if not os.path.exists(folder):
                    os.makedirs(folder)
                validation_performance = self.test_modality(folder, self.conf.modality, 'validation', False)
                if self.conf.batches > check_batch_iters:
                    self.write_csv(os.path.join(self.conf.folder, 'test-performance.csv'),
                                   self.epoch, self.batch, test_performance)
                    self.write_csv(os.path.join(self.conf.folder, 'validation-performance.csv'),
                                   self.epoch, self.batch, validation_performance)
            epoch_loss['Test_Performance_Dice'].append(test_performance)

            for n in loss_names:
                epoch_loss_list.append((n, np.mean(epoch_loss[n])))
                total_loss[n].append(np.mean(epoch_loss[n]))

            if self.epoch < 5:
                log.info(str('Epoch %d/%d:\n' + ''.join([l + ' Loss = %.3f\n' for l in loss_names])) %
                         ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1] for l in loss_names)))
            else:
                info_str = str('Epoch %d/%d:\n' % (self.epoch, self.conf.epochs))
                loss_info = ''
                for l in loss_names:
                    loss_info = loss_info + l + ' Loss = %.3f->%.3f->%.3f->%.3f->%.3f\n' % \
                                (total_loss[l][-5],
                                 total_loss[l][-4],
                                 total_loss[l][-3],
                                 total_loss[l][-2],
                                 total_loss[l][-1])
                log.info(info_str + loss_info)
            log.info("BestTest:%f" % best_performance)
            log.info('Epoch %d/%d' % (self.epoch + 1, self.conf.epochs))
            logs = {l: total_loss[l][-1] for l in loss_names}
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            if self.stop_criterion(es, logs) and self.epoch > self.conf.epochs / 2:
                log.info('Finished training from early stopping criterion')
                self.img_clb.on_epoch_end(self.epoch)
                break
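
The schedule at the top of this example decays the learning rate
exponentially: lr(epoch) = lr0 * exp(-coef * (epoch + 1)). A minimal sketch
with illustrative lr0/coef values (not this repo's configuration):

import math
from tensorflow.keras.callbacks import LearningRateScheduler

lr0, coef = 1e-3, 0.05

def schedule(epoch):
    return lr0 * math.exp(coef * (-epoch - 1))

lr_callback = LearningRateScheduler(schedule)
for epoch in range(3):
    print(epoch, schedule(epoch))  # 0: ~9.51e-04, 1: ~9.05e-04, 2: ~8.61e-04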
Example #8
def train_model(model, data, config, include_tensorboard):
	model_history = History()
	model_history.on_train_begin()
	saver = ModelCheckpoint(full_path(config.model_file()), verbose=1, save_best_only=True, period=1)
	saver.set_model(model)
	early_stopping = EarlyStopping(min_delta=config.min_delta, patience=config.patience, verbose=1)
	early_stopping.set_model(model)
	early_stopping.on_train_begin()
	csv_logger = CSVLogger(full_path(config.csv_log_file()))
	csv_logger.on_train_begin()
	if include_tensorboard:
		tensorboard = TensorBoard(histogram_freq=10, write_images=True)
		tensorboard.set_model(model)
	else:
		tensorboard = Callback()

	epoch = 0
	stop = False
	while epoch <= config.max_epochs and not stop:
		epoch_history = History()
		epoch_history.on_train_begin()
		valid_sizes = []
		train_sizes = []
		print("Epoch:", epoch)
		for dataset in data.datasets:
			print("dataset:", dataset.name)
			model.reset_states()
			dataset.reset_generators()

			valid_sizes.append(dataset.valid_generators[0].size())
			train_sizes.append(dataset.train_generators[0].size())
			fit_history = model.fit_generator(dataset.train_generators[0],
				dataset.train_generators[0].size(), 
				nb_epoch=1, 
				verbose=0, 
				validation_data=dataset.valid_generators[0], 
				nb_val_samples=dataset.valid_generators[0].size())

			epoch_history.on_epoch_end(epoch, last_logs(fit_history))

			train_sizes.append(dataset.train_generators[1].size())
			fit_history = model.fit_generator(dataset.train_generators[1],
				dataset.train_generators[1].size(),
				nb_epoch=1, 
				verbose=0)

			epoch_history.on_epoch_end(epoch, last_logs(fit_history))

		epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
		model_history.on_epoch_end(epoch, logs=epoch_logs)
		saver.on_epoch_end(epoch, logs=epoch_logs)
		early_stopping.on_epoch_end(epoch, epoch_logs)
		csv_logger.on_epoch_end(epoch, epoch_logs)
		tensorboard.on_epoch_end(epoch, epoch_logs)
		epoch += 1

		if early_stopping.stopped_epoch > 0:
			stop = True

	early_stopping.on_train_end()
	csv_logger.on_train_end()
	tensorboard.on_train_end({})
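
last_logs and average_logs are helpers from the surrounding repo and are not
shown. A plausible sketch of what they might do (hypothetical, not the repo's
code): take each fit's final-epoch logs, then average every metric weighted by
dataset size.

import numpy as np

def last_logs(fit_history):
    # logs from the final epoch of the History returned by fit_generator
    return {k: v[-1] for k, v in fit_history.history.items()}

def average_logs(epoch_history, train_sizes, valid_sizes):
    # size-weighted mean of each metric collected across the per-dataset fits;
    # val_* metrics are weighted by validation sizes, the rest by train sizes
    out = {}
    for key, values in epoch_history.history.items():
        sizes = valid_sizes if key.startswith('val_') else train_sizes
        if len(sizes) != len(values):
            sizes = [1] * len(values)  # fall back to a plain mean
        out[key] = float(np.average(values, weights=sizes))
    return out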
Example #9
    def train(self):
        log.info('Training Model')

        self.init_train_data()

        self.init_image_callback()
        sl = SaveLoss(self.conf.folder)
        cl = CSVLogger(self.conf.folder + '/training.csv')
        cl.on_train_begin()

        es = EarlyStopping('val_loss', min_delta=0.01, patience=100)
        es.model = self.model.Segmentor
        es.on_train_begin()

        loss_names = self.get_loss_names()
        total_loss = {n: [] for n in loss_names}

        progress_bar = Progbar(target=self.conf.batches * self.conf.batch_size)

        for self.epoch in range(self.conf.epochs):
            log.info('Epoch %d/%d' % (self.epoch, self.conf.epochs))

            epoch_loss = {n: [] for n in loss_names}
            epoch_loss_list = []

            D_initial_weights = np.mean(
                [np.mean(w) for w in self.model.D_trainer.get_weights()])
            G_initial_weights = np.mean(
                [np.mean(w) for w in self.model.G_trainer.get_weights()])
            for self.batch in range(self.conf.batches):
                # real_pools = self.add_to_pool(data, real_pools)
                self.train_batch(epoch_loss)

                progress_bar.update((self.batch + 1) * self.conf.batch_size)

            G_final_weights = np.mean(
                [np.mean(w) for w in self.model.G_trainer.get_weights()])
            D_final_weights = np.mean(
                [np.mean(w) for w in self.model.D_trainer.get_weights()])

            assert self.gen_unlabelled is None or not self.model.D_trainer.trainable \
                   or D_initial_weights != D_final_weights
            assert G_initial_weights != G_final_weights

            self.validate(epoch_loss)

            for n in loss_names:
                epoch_loss_list.append((n, np.mean(epoch_loss[n])))
                total_loss[n].append(np.mean(epoch_loss[n]))
            log.info(
                str('Epoch %d/%d: ' +
                    ', '.join([l + ' Loss = %.3f' for l in loss_names])) %
                ((self.epoch, self.conf.epochs) + tuple(total_loss[l][-1]
                                                        for l in loss_names)))
            logs = {l: total_loss[l][-1] for l in loss_names}

            cl.model = self.model.D_Mask
            cl.model.stop_training = False
            cl.on_epoch_end(self.epoch, logs)
            sl.on_epoch_end(self.epoch, logs)

            # Plot some example images
            self.img_clb.on_epoch_end(self.epoch)

            self.model.save_models()

            if self.stop_criterion(es, logs):
                log.info('Finished training from early stopping criterion')
                break
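
Several examples above end with self.stop_criterion(es, logs) (Example #6 also
passes the epoch), and the helper itself is not shown. A plausible sketch
(hypothetical, not the repo's code): feed the epoch's logs to EarlyStopping
and report whether it tripped.

def stop_criterion(es, epoch, logs):
    # EarlyStopping sets stop_training on whatever object es.model points to
    es.on_epoch_end(epoch, logs)
    return getattr(es.model, 'stop_training', False)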