示例#1
0
 def on_epoch_end(self, epoch, logs=None):
     """Save a checkpoint through ``ModelCheckpoint`` and delete the superseded one.

     When ``save_best_only`` is set and the monitored value beats ``self.best``
     (as judged by ``self.monitor_op``), the path stored in
     ``self.last_saved_model`` is remembered for deletion and replaced by the
     path formatted from ``self.filepath`` for this epoch. The parent
     ``ModelCheckpoint.on_epoch_end`` is then invoked, and finally the
     remembered file is removed.

     :param int epoch: epoch index forwarded to the parent callback and used
         to format ``self.filepath``
     :param dict logs: metric values for this epoch; ``None`` is treated as
         an empty dict
     """
     # ``logs`` defaults to None, but both ``.get`` and ``**logs`` need a dict.
     logs = logs or {}

     stale_model = None
     if self.save_best_only:
         current = logs.get(self.monitor)
         # A missing metric cannot be compared against ``self.best``.
         if current is not None and self.monitor_op(current, self.best):
             stale_model = self.last_saved_model
             self.last_saved_model = self.filepath.format(epoch=epoch, **logs)

     ModelCheckpoint.on_epoch_end(self, epoch=epoch, logs=logs)

     # Skip the removal when nothing was superseded, when the "old" path is the
     # very file that was just (re)written -- which happens whenever
     # ``filepath`` has no placeholders -- or when the file is already gone.
     if (stale_model
             and stale_model != self.last_saved_model
             and os.path.isfile(stale_model)):
         os.remove(stale_model)
示例#2
0
    def on_epoch_end(self, epoch, logs=None):
        """Run the parent checkpoint, then record this epoch's logs and dump them as JSON.

        The epoch passed on to ``ModelCheckpoint.on_epoch_end`` and stored under
        ``"last_epoch"`` is ``epoch + self.epochOffset + 1``. Every entry of
        ``logs`` is appended to the list kept under the same key in
        ``self.histobj.history``, and the whole history dict is written to
        ``self.historyFilename``.

        :param int epoch: epoch index as supplied by the caller
        :param dict logs: metric values for this epoch; ``None`` is treated as
            an empty dict
        """
        # Avoid sharing one mutable default dict between calls.
        logs = {} if logs is None else logs
        history = self.histobj.history

        epoch = epoch + self.epochOffset + 1
        ModelCheckpoint.on_epoch_end(self, epoch, logs)

        history["last_epoch"] = epoch
        for name, value in logs.items():
            history.setdefault(name, []).append(value)

        # json.dump emits text, so the file must be opened in text mode; the
        # context manager guarantees the handle is flushed and closed.
        with open(self.historyFilename, "w") as history_file:
            json.dump(history, history_file)
示例#3
0
    def _gcp_on_epoch_end(self, epoch, logs=None):
        """Checkpoint to the local ``self.filepath`` and copy the result to its final path.

        The stock ``KerasModelCheckpoint.on_epoch_end`` runs first. If it left a
        non-empty file at ``self.filepath``, that file is copied through
        ``file_io.FileIO`` to ``self._original_filepath`` (formatted with
        ``epoch + 1`` and the log values) and the local copy is deleted.
        Otherwise a warning is logged and nothing is copied.

        :param int epoch: epoch index supplied by the training loop
        :param dict logs: metric values for this epoch; may be ``None``
        """
        # The parent callback writes the checkpoint to the temporary local file.
        KerasModelCheckpoint.on_epoch_end(self, epoch, logs=logs)

        format_values = logs if logs else {}
        local_path = self.filepath

        # Nothing to upload when the parent did not produce a usable file.
        if not os.path.exists(local_path):
            log.warning("Checkpoint file does not seem to exists. Ignoring")
            return
        if os.path.getsize(local_path) == 0:
            log.warning("File empty, no checkpoint has been saved")
            return

        destination = self._original_filepath.format(epoch=epoch + 1,
                                                     **format_values)

        # Byte-for-byte copy from the local file to the final location.
        with file_io.FileIO(local_path, mode='rb') as source, \
                file_io.FileIO(destination, mode='w+b') as target:
            target.write(source.read())

        # The local temporary checkpoint is no longer needed.
        os.remove(local_path)
示例#4
0
    def train_rtvsrgan(self,
                       epochs=None,
                       batch_size=None,
                       modelname=None,
                       datapath_train=None,
                       datapath_validation=None,
                       steps_per_validation=None,
                       datapath_test=None,
                       workers=None,
                       max_queue_size=None,
                       first_epoch=None,
                       print_frequency=None,
                       crops_per_image=None,
                       log_weight_frequency=None,
                       log_weight_path=None,
                       log_tensorboard_path=None,
                       log_tensorboard_update_freq=None,
                       log_test_frequency=None,
                       log_test_path=None,
                       media_type='i'):
        """Train ``self.rtvsrgan`` by alternating discriminator and generator updates.

        Each "epoch" is one update iteration: ``self.ra_discriminator`` is
        trained on a single batch (once with real targets, once with fake
        targets), then ``self.rtvsrgan`` is trained on ten further batches.

        Although every argument defaults to ``None``, ``epochs``,
        ``first_epoch``, ``print_frequency``, ``modelname`` and
        ``log_weight_path`` are used unconditionally and must be supplied.

        :param int epochs: how many epochs to train the network for
        :param int batch_size: batch size handed to the data loaders; also the
            leading dimension of the discriminator target arrays
        :param str modelname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for validation data; when set, the
            generator is evaluated and the checkpoint callback runs every epoch
        :param int steps_per_validation: ``steps`` passed to ``evaluate_generator``
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: worker count passed to the batch enqueuer
        :param int max_queue_size: queue size passed to the batch enqueuer
        :param int first_epoch: index of the first epoch of this run
        :param int print_frequency: how often (in epochs) to print progress to terminal
        :param int crops_per_image: forwarded to the training/validation data loaders
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent;
            falsy disables TensorBoard logging
        :param log_tensorboard_update_freq: ``update_freq`` passed to ``TensorBoard``
        :param int log_test_frequency: how often (in epochs) test images are plotted
        :param str log_test_path: where should test results be saved
        :param str media_type: forwarded to the data loaders (default ``'i'``)
        """

        # Create data loaders
        train_loader = DataLoader(datapath_train, batch_size, self.height_hr,
                                  self.width_hr, self.upscaling_factor,
                                  crops_per_image, media_type, self.channels,
                                  self.colorspace)

        # Validation data loader
        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image, media_type,
                                           self.channels, self.colorspace)

        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(datapath_test, 1, self.height_hr,
                                     self.width_hr, self.upscaling_factor, 1,
                                     media_type, self.channels,
                                     self.colorspace)

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(train_loader,
                                   use_multiprocessing=True,
                                   shuffle=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)

        # Everything after start() runs under try/finally so the worker
        # processes are always shut down, even if training raises.
        try:
            output_generator = enqueuer.get()

            # Callback: save weights after each epoch
            modelcheckpoint = ModelCheckpoint(
                os.path.join(
                    log_weight_path,
                    modelname + '2_{}X.h5'.format(self.upscaling_factor)),
                monitor='Perceptual_loss',
                save_best_only=True,
                save_weights_only=True,
                mode='min',
                verbose=1)
            modelcheckpoint.set_model(self.generator)

            # Callback: tensorboard. Stays None when logging is disabled so
            # the training loop can skip it instead of hitting a NameError.
            tensorboard = None
            if log_tensorboard_path:
                tensorboard = TensorBoard(
                    log_dir=os.path.join(log_tensorboard_path, modelname),
                    histogram_freq=0,
                    batch_size=batch_size,
                    write_graph=True,
                    write_grads=True,
                    update_freq=log_tensorboard_update_freq)
                tensorboard.set_model(self.rtvsrgan)
            else:
                print(
                    ">> Not logging to tensorboard since no log_tensorboard_path is set"
                )

            # Learning rate scheduler: halve the rate at fixed epochs
            def lr_scheduler(epoch, lr):
                factor = 0.5
                decay_step = (500, 1000, 1500, 2000)
                if epoch in decay_step:
                    return lr * factor
                return lr

            lr_scheduler_gan = LearningRateScheduler(lr_scheduler, verbose=1)
            lr_scheduler_gan.set_model(self.rtvsrgan)
            lr_scheduler_gen = LearningRateScheduler(lr_scheduler, verbose=0)
            lr_scheduler_gen.set_model(self.generator)
            lr_scheduler_dis = LearningRateScheduler(lr_scheduler, verbose=0)
            lr_scheduler_dis.set_model(self.discriminator)
            lr_scheduler_ra = LearningRateScheduler(lr_scheduler, verbose=0)
            lr_scheduler_ra.set_model(self.ra_discriminator)

            # Callback: format input value
            def named_logs(model, logs):
                """Transform train_on_batch return value to dict expected by on_batch_end callback"""
                return dict(zip(model.metrics_names, logs))

            # Shape of output from discriminator, with the batch dimension filled in
            disciminator_output_shape = list(self.ra_discriminator.output_shape)
            disciminator_output_shape[0] = batch_size
            disciminator_output_shape = tuple(disciminator_output_shape)

            # VALID / FAKE targets for discriminator
            real = np.ones(disciminator_output_shape)
            fake = np.zeros(disciminator_output_shape)

            # Losses accumulated since the last progress print
            print_losses = {"GAN": [], "D": []}
            start_epoch = datetime.datetime.now()

            # One bound for both the loop and the progress message
            last_epoch = int(epochs) + first_epoch

            # Loop through epochs / iterations
            for epoch in range(first_epoch, last_epoch):
                lr_scheduler_gan.on_epoch_begin(epoch)
                lr_scheduler_ra.on_epoch_begin(epoch)
                lr_scheduler_dis.on_epoch_begin(epoch)
                lr_scheduler_gen.on_epoch_begin(epoch)

                # Start epoch time
                if epoch % print_frequency == 0:
                    print("\nEpoch {}/{}:".format(epoch + 1, last_epoch))
                    start_epoch = datetime.datetime.now()

                # Train discriminator
                self.discriminator.trainable = True
                self.ra_discriminator.trainable = True

                imgs_lr, imgs_hr = next(output_generator)
                generated_hr = self.generator.predict(imgs_lr)

                real_loss = self.ra_discriminator.train_on_batch(
                    [imgs_hr, generated_hr], real)
                fake_loss = self.ra_discriminator.train_on_batch(
                    [generated_hr, imgs_hr], fake)
                discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

                # Train generator
                self.discriminator.trainable = False
                self.ra_discriminator.trainable = False

                for _ in tqdm(range(10), ncols=60,
                              desc=">> Training generator"):
                    imgs_lr, imgs_hr = next(output_generator)
                    gan_loss = self.rtvsrgan.train_on_batch(
                        [imgs_lr, imgs_hr], [imgs_hr, real, imgs_hr])

                # Callbacks: only the last generator batch of the epoch is logged
                logs = named_logs(self.rtvsrgan, gan_loss)
                if tensorboard is not None:
                    tensorboard.on_epoch_end(epoch, logs)

                # Callbacks: validation and checkpointing
                if datapath_validation:
                    validation_losses = self.generator.evaluate_generator(
                        validation_loader,
                        steps=steps_per_validation,
                        use_multiprocessing=False,
                        workers=1)
                    modelcheckpoint.on_epoch_end(epoch, logs)

                # Save losses
                print_losses['GAN'].append(gan_loss)
                print_losses['D'].append(discriminator_loss)

                # Show the progress
                if epoch % print_frequency == 0:
                    g_avg_loss = np.array(print_losses['GAN']).mean(axis=0)
                    d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                    print(">> Time: {}s\n>> GAN: {}\n>> Discriminator: {}".format(
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.rtvsrgan.metrics_names, g_avg_loss)
                        ]), ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.discriminator.metrics_names, d_avg_loss)
                        ])))
                    print_losses = {"GAN": [], "D": []}

                    # Report the validation losses computed above, if any
                    if datapath_validation:
                        print(">> Validation Losses: {}".format(", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.generator.metrics_names,
                                validation_losses)
                        ])))

                # If test images are supplied, run model on them and save to log_test_path
                if datapath_test and epoch % log_test_frequency == 0:
                    plot_test_images(self.generator,
                                     test_loader,
                                     datapath_test,
                                     log_test_path,
                                     epoch,
                                     modelname,
                                     channels=self.channels,
                                     colorspace=self.colorspace)

                # Check if we should save the network weights
                if log_weight_frequency and epoch % log_weight_frequency == 0:
                    # Save the network weights
                    self.save_weights(os.path.join(log_weight_path, modelname))
        finally:
            # Shut down the batch-preparation workers started above.
            enqueuer.stop()
示例#5
0
# Caption-model training script: builds the model once, then cycles forever
# over the seven pre-split training shards, fitting one epoch per shard.
embedding_dim = 300

word_indexes = get_word_indexes(word_indexes_filepath)

model = create_model_3(max_caption_len, word_indexes, embedding_dim, feature_dim, embedding_matrix_filepath)

# Passed to every fit() call below as a callback; writes model_weights.hdf5.
checkpointer = ModelCheckpoint(filepath="model_weights.hdf5", verbose=0)

# Uncomment to resume from a previous run:
# model.load_weights('model_weights.hdf5')


while True:
    for i in range(1, 8):
        captions = 'dataset/frequencies/train_captions_{}.npy'.format(i)
        features = 'dataset/frequencies/train_features_{}.npy'.format(i)

        print(captions)
        print(features)

        train_captions = get_captions(captions)
        train_features = get_features(features)

        # Inputs are the captions without their last token; targets are the
        # captions shifted one position left, with a trailing axis added.
        X_captions = train_captions[:, :-1]
        Y = np.expand_dims(train_captions[:, 1:], -1)
        X = [train_features, X_captions]

        try:
            model.fit(X, Y, nb_epoch=1, callbacks=[checkpointer], shuffle=True)
        except KeyboardInterrupt:
            # Save the partially trained model, then let the interrupt
            # propagate. Swallowing it here would drop straight back into the
            # endless loop and make Ctrl-C unable to stop training.
            checkpointer.on_epoch_end(0)
            raise
示例#6
0
        for i, k in enumerate(args.style_layers):
            # log['style_loss'][k].append(out[offset + i])
            key = "style_loss_%s" % k
            style_block_loss = out[offset + i]
            logs[key] = style_block_loss
            style_loss += style_block_loss
        logs['style_loss'] = style_loss

        stop_time = time.time()
        print('Iteration %d/%d: loss = %f. t = %f (%f)' %
              (it + 1, args.num_iterations, out[0], stop_time - start_time,
               stop_time2 - start_time2))

        if not ((it + 1) % args.save_every):
            print('Saving checkpoint in %s...' % (args.checkpoint_path))
            model_checkpoint.on_epoch_end(it)
            print('Checkpoint saved.')

        # tensorboard
        if (log_dir):
            tensorboard.on_epoch_end(it, logs)

        start_time = time.time()

    # close callback
    if (log_dir):
        tensorboard.on_train_end(None)

    # save model
    pastiche_net.save_weights(weights_path)
示例#7
0
def train_model(model, data, config, include_tensorboard):
    """Train ``model`` over every dataset in ``data``, driving the Keras callbacks by hand.

    For each epoch, every dataset is fitted for one pass on its first training
    generator (with validation on its first validation generator) and one pass
    on its second training generator. The per-fit logs are combined with
    ``average_logs`` and fed to the history, checkpoint, early-stopping, CSV
    and (optional) TensorBoard callbacks.

    Training ends once ``epoch`` exceeds ``config.max_epochs`` or early
    stopping reports a positive ``stopped_epoch``.

    :param model: model exposing ``fit_generator`` and ``reset_states``
    :param data: object whose ``datasets`` attribute is iterated each epoch
    :param config: supplies ``model_file()``, ``csv_log_file()``,
        ``min_delta``, ``patience`` and ``max_epochs``
    :param bool include_tensorboard: use a real ``TensorBoard`` callback when
        true, otherwise a plain ``Callback`` instance as a stand-in
    """
    # The callbacks are created and started manually because the loop below
    # replaces Keras' own epoch loop.
    model_history = History()
    model_history.on_train_begin()

    saver = ModelCheckpoint(full_path(config.model_file()), verbose=1, save_best_only=True, period=1)
    saver.set_model(model)

    early_stopping = EarlyStopping(min_delta=config.min_delta, patience=config.patience, verbose=1)
    early_stopping.set_model(model)
    early_stopping.on_train_begin()

    csv_logger = CSVLogger(full_path(config.csv_log_file()))
    csv_logger.on_train_begin()

    if include_tensorboard:
        board = TensorBoard(histogram_freq=10, write_images=True)
        board.set_model(model)
    else:
        # Base callback used as a stand-in so the calls below need no branching.
        board = Callback()

    epoch = 0
    while epoch <= config.max_epochs:
        # Fresh per-epoch history collecting the logs of every individual fit.
        epoch_history = History()
        epoch_history.on_train_begin()
        valid_sizes, train_sizes = [], []
        print("Epoch:", epoch)

        for dataset in data.datasets:
            print("dataset:", dataset.name)
            model.reset_states()
            dataset.reset_generators()

            # First pass: first training generator, validated on the first
            # validation generator.
            valid_sizes.append(dataset.valid_generators[0].size())
            train_sizes.append(dataset.train_generators[0].size())
            validated_fit = model.fit_generator(
                dataset.train_generators[0],
                dataset.train_generators[0].size(),
                nb_epoch=1,
                verbose=0,
                validation_data=dataset.valid_generators[0],
                nb_val_samples=dataset.valid_generators[0].size())
            epoch_history.on_epoch_end(epoch, last_logs(validated_fit))

            # Second pass: second training generator, no validation.
            train_sizes.append(dataset.train_generators[1].size())
            plain_fit = model.fit_generator(
                dataset.train_generators[1],
                dataset.train_generators[1].size(),
                nb_epoch=1,
                verbose=0)
            epoch_history.on_epoch_end(epoch, last_logs(plain_fit))

        epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
        model_history.on_epoch_end(epoch, logs=epoch_logs)
        saver.on_epoch_end(epoch, logs=epoch_logs)
        early_stopping.on_epoch_end(epoch, epoch_logs)
        csv_logger.on_epoch_end(epoch, epoch_logs)
        board.on_epoch_end(epoch, epoch_logs)
        epoch += 1

        # A positive stopped_epoch is taken as the early-stopping signal.
        if early_stopping.stopped_epoch > 0:
            break

    early_stopping.on_train_end()
    csv_logger.on_train_end()
    board.on_train_end({})