def create_initial_model(name):
    full_filename = os.path.join(conf['MODEL_DIR'], name) + ".h5"
    if os.path.isfile(full_filename):
        model = load_model(full_filename, custom_objects={'loss': loss})
        return model

    model = build_model(name)

    # Save the graph in TensorBoard. This graph has the name scopes, making it
    # look good in TensorBoard; the loaded models will not have the scopes.
    tf_callback = TensorBoard(log_dir=os.path.join(conf['LOG_DIR'], name),
                              histogram_freq=0, batch_size=1,
                              write_graph=True, write_grads=False)
    tf_callback.set_model(model)
    tf_callback.on_epoch_end(0)
    tf_callback.on_train_end(None)  # on_train_end expects a logs dict or None

    from self_play import self_play
    self_play(model,
              n_games=conf['N_GAMES'],
              mcts_simulations=conf['MCTS_SIMULATIONS'])
    model.save(full_filename)
    best_filename = os.path.join(conf['MODEL_DIR'], 'best_model.h5')
    model.save(best_filename)
    return model
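# The snippet above drives the TensorBoard callback by hand, outside of
# model.fit(). A minimal self-contained sketch of that pattern, assuming
# tf.keras (the model and the 'loss' log key are illustrative):
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')

tb = tf.keras.callbacks.TensorBoard(log_dir='./logs', write_graph=True)
tb.set_model(model)                        # attach the model once
for epoch in range(3):
    tb.on_epoch_end(epoch, {'loss': 0.1})  # log scalars under arbitrary names
tb.on_train_end(None)                      # flush and close the writer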
class Trainer:
    """Class object to set up and carry out the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
    to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GANS fashion if a discriminator is
    provided, otherwise carries out a regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep
            features component of the perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: float.
        loss_weights: dictionary, used to weigh the components of the loss function.
            Contains 'generator' for the generator loss component, and can contain
            'discriminator' and 'feature_extractor' for the discriminator and deep
            features components respectively.
        logs_dir: path to the directory where the tensorboard logs are saved.
        weights_dir: path to the directory where the weights are saved.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation: integer, number of validation samples used at training from the validation set.
        flatness: dictionary. Determines the 'flatness' threshold level for the
            training patches. See the TrainerHelper class for more details.
        lr_decay_frequency: integer, every how many epochs the learning rate is reduced.
        lr_decay_factor: 0 < float < 1, learning rate reduction multiplicative factor.

    Methods:
        train: combines the networks and triggers training with the specified settings.
""" def __init__( self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights={ 'generator': 1.0, 'discriminator': 0.003, 'feature_extractor': 1 / 12 }, log_dirs={ 'logs': 'logs', 'weights': 'weights' }, fallback_save_every_n_epochs=2, dataname=None, weights_generator=None, weights_discriminator=None, n_validation=None, flatness={ 'min': 0.0, 'increase_frequency': None, 'increase': 0.0, 'max': 0.0 }, learning_rate={ 'initial_value': 0.0004, 'decay_frequency': 100, 'decay_factor': 0.5 }, adam_optimizer={ 'beta1': 0.9, 'beta2': 0.999, 'epsilon': None }, losses={ 'generator': 'mae', 'discriminator': 'binary_crossentropy', 'feature_extractor': 'mse', }, metrics={'generator': 'PSNR_Y'}, ): self.generator = generator self.discriminator = discriminator self.feature_extractor = feature_extractor self.scale = generator.scale self.lr_patch_size = generator.patch_size self.learning_rate = learning_rate self.loss_weights = loss_weights self.weights_generator = weights_generator self.weights_discriminator = weights_discriminator self.adam_optimizer = adam_optimizer self.dataname = dataname self.flatness = flatness self.n_validation = n_validation self.losses = losses self.log_dirs = log_dirs self.metrics = metrics if self.metrics['generator'] == 'PSNR_Y': self.metrics['generator'] = PSNR_Y elif self.metrics['generator'] == 'PSNR': self.metrics['generator'] = PSNR self._parameters_sanity_check() self.model = self._combine_networks() self.settings = {} self.settings['training_parameters'] = locals() self.settings['training_parameters'][ 'lr_patch_size'] = self.lr_patch_size self.settings = self.update_training_config(self.settings) self.logger = get_logger(__name__) self.helper = TrainerHelper( generator=self.generator, weights_dir=log_dirs['weights'], logs_dir=log_dirs['logs'], lr_train_dir=lr_train_dir, feature_extractor=self.feature_extractor, discriminator=self.discriminator, dataname=dataname, weights_generator=self.weights_generator, weights_discriminator=self.weights_discriminator, fallback_save_every_n_epochs=fallback_save_every_n_epochs, ) self.train_dh = DataHandler( lr_dir=lr_train_dir, hr_dir=hr_train_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=None, ) self.valid_dh = DataHandler( lr_dir=lr_valid_dir, hr_dir=hr_valid_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=n_validation, ) def _parameters_sanity_check(self): """ Parameteres sanity check. """ if self.discriminator: assert self.lr_patch_size * self.scale == self.discriminator.patch_size self.adam_optimizer if self.feature_extractor: assert self.lr_patch_size * self.scale == self.feature_extractor.patch_size check_parameter_keys( self.learning_rate, needed_keys=['initial_value'], optional_keys=['decay_factor', 'decay_frequency'], default_value=None, ) check_parameter_keys( self.flatness, needed_keys=[], optional_keys=['min', 'increase_frequency', 'increase', 'max'], default_value=0.0, ) check_parameter_keys( self.adam_optimizer, needed_keys=['beta1', 'beta2'], optional_keys=['epsilon'], default_value=None, ) check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights']) def _combine_networks(self): """ Constructs the combined model which contains the generator network, as well as discriminator and geature extractor, if any are defined. 
""" lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, )) sr = self.generator.model(lr) outputs = [sr] losses = [self.losses['generator']] loss_weights = [self.loss_weights['generator']] if self.discriminator: self.discriminator.model.trainable = False validity = self.discriminator.model(sr) outputs.append(validity) losses.append(self.losses['discriminator']) loss_weights.append(self.loss_weights['discriminator']) if self.feature_extractor: self.feature_extractor.model.trainable = False sr_feats = self.feature_extractor.model(sr) outputs.extend([*sr_feats]) losses.extend([self.losses['feature_extractor']] * len(sr_feats)) loss_weights.extend( [self.loss_weights['feature_extractor'] / len(sr_feats)] * len(sr_feats)) combined = Model(inputs=lr, outputs=outputs) # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows optimizer = Adam( beta_1=self.adam_optimizer['beta1'], beta_2=self.adam_optimizer['beta2'], lr=self.learning_rate['initial_value'], epsilon=self.adam_optimizer['epsilon'], ) combined.compile(loss=losses, loss_weights=loss_weights, optimizer=optimizer, metrics=self.metrics) return combined def _lr_scheduler(self, epoch): """ Scheduler for the learning rate updates. """ n_decays = epoch // self.learning_rate['decay_frequency'] lr = self.learning_rate['initial_value'] * ( self.learning_rate['decay_factor']**n_decays) # no lr below minimum control 10e-7 return max(1e-7, lr) def _flatness_scheduler(self, epoch): if self.flatness['increase']: n_increases = epoch // self.flatness['increase_frequency'] else: return self.flatness['min'] f = self.flatness['min'] + n_increases * self.flatness['increase'] return min(self.flatness['max'], f) def _load_weights(self): """ Loads the pretrained weights from the given path, if any is provided. If a discriminator is defined, does the same. """ if self.weights_generator: self.model.get_layer('generator').load_weights( self.weights_generator) if self.discriminator: if self.weights_discriminator: self.model.get_layer('discriminator').load_weights( self.weights_discriminator) self.discriminator.model.load_weights( self.weights_discriminator) def _format_losses(self, prefix, losses, model_metrics): """ Creates a dictionary for tensorboard tracking. """ return dict(zip([prefix + m for m in model_metrics], losses)) def update_training_config(self, settings): """ Summarizes training setting. 
""" _ = settings['training_parameters'].pop('weights_generator') _ = settings['training_parameters'].pop('self') _ = settings['training_parameters'].pop('generator') _ = settings['training_parameters'].pop('discriminator') _ = settings['training_parameters'].pop('feature_extractor') settings['generator'] = {} settings['generator']['name'] = self.generator.name settings['generator']['parameters'] = self.generator.params settings['generator']['weights_generator'] = self.weights_generator _ = settings['training_parameters'].pop('weights_discriminator') if self.discriminator: settings['discriminator'] = {} settings['discriminator']['name'] = self.discriminator.name settings['discriminator'][ 'weights_discriminator'] = self.weights_discriminator else: settings['discriminator'] = None if self.discriminator: settings['feature_extractor'] = {} settings['feature_extractor']['name'] = self.feature_extractor.name settings['feature_extractor'][ 'layers'] = self.feature_extractor.layers_to_extract else: settings['feature_extractor'] = None return settings def train(self, epochs, steps_per_epoch, batch_size, monitored_metrics): """ Carries on the training for the given number of epochs. Sends the losses to Tensorboard. Args: epochs: how many epochs to train for. steps_per_epoch: how many batches epoch. batch_size: amount of images per batch. monitored_metrics: dictionary, the keys are the metrics that are monitored for the weights saving logic. The values are the mode that trigger the weights saving ('min' vs 'max'). """ self.settings['training_parameters'][ 'steps_per_epoch'] = steps_per_epoch self.settings['training_parameters']['batch_size'] = batch_size starting_epoch = self.helper.initialize_training( self) # load_weights, creates folders, creates basename self.tensorboard = TensorBoard( log_dir=self.helper.callback_paths['logs']) self.tensorboard.set_model(self.model) # validation data validation_set = self.valid_dh.get_validation_set(batch_size) y_validation = [validation_set['hr']] if self.discriminator: discr_out_shape = list( self.discriminator.model.outputs[0].shape)[1:4] valid = np.ones([batch_size] + discr_out_shape) fake = np.zeros([batch_size] + discr_out_shape) validation_valid = np.ones([len(validation_set['hr'])] + discr_out_shape) y_validation.append(validation_valid) if self.feature_extractor: validation_feats = self.feature_extractor.model.predict( validation_set['hr']) y_validation.extend([*validation_feats]) for epoch in range(starting_epoch, epochs): self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch, tot_eps=epochs)) K.set_value(self.model.optimizer.lr, self._lr_scheduler(epoch=epoch)) self.logger.info('Current learning rate: {}'.format( K.eval(self.model.optimizer.lr))) flatness = self._flatness_scheduler(epoch) if flatness: self.logger.info( 'Current flatness treshold: {}'.format(flatness)) epoch_start = time() for step in tqdm(range(steps_per_epoch)): batch = self.train_dh.get_batch(batch_size, flatness=flatness) y_train = [batch['hr']] training_losses = {} ## Discriminator training if self.discriminator: sr = self.generator.model.predict(batch['lr']) d_loss_real = self.discriminator.model.train_on_batch( batch['hr'], valid) d_loss_fake = self.discriminator.model.train_on_batch( sr, fake) d_loss_fake = self._format_losses( 'train_d_fake_', d_loss_fake, self.discriminator.model.metrics_names) d_loss_real = self._format_losses( 'train_d_real_', d_loss_real, self.discriminator.model.metrics_names) training_losses.update(d_loss_real) training_losses.update(d_loss_fake) 
                    y_train.append(valid)

                ## Generator training
                if self.feature_extractor:
                    hr_feats = self.feature_extractor.model.predict(batch['hr'])
                    y_train.extend([*hr_feats])

                model_losses = self.model.train_on_batch(batch['lr'], y_train)
                model_losses = self._format_losses('train_', model_losses,
                                                   self.model.metrics_names)
                training_losses.update(model_losses)

                self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step,
                                              training_losses)
                self.logger.debug('Losses at step {s}:\n {l}'.format(
                    s=step, l=training_losses))

            elapsed_time = time() - epoch_start
            self.logger.info('Epoch {} took {:10.1f}s'.format(epoch, elapsed_time))

            validation_losses = self.model.evaluate(validation_set['lr'],
                                                    y_validation,
                                                    batch_size=batch_size)
            validation_losses = self._format_losses('val_', validation_losses,
                                                    self.model.metrics_names)

            if epoch == starting_epoch:
                remove_metrics = []
                for metric in monitored_metrics:
                    if (metric not in training_losses) and (metric not in validation_losses):
                        msg = ' '.join([metric, 'is NOT among the model metrics, removing it.'])
                        self.logger.error(msg)
                        remove_metrics.append(metric)
                for metric in remove_metrics:
                    _ = monitored_metrics.pop(metric)

            # should average train metrics
            end_losses = {}
            end_losses.update(validation_losses)
            end_losses.update(training_losses)

            self.helper.on_epoch_end(
                epoch=epoch,
                losses=end_losses,
                generator=self.model.get_layer('generator'),
                discriminator=self.discriminator,
                metrics=monitored_metrics,
            )
            self.tensorboard.on_epoch_end(epoch, validation_losses)
        self.tensorboard.on_train_end(None)
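# The check_parameter_keys helper called in _parameters_sanity_check above is
# not shown in this snippet. A plausible minimal implementation, inferred from
# its call sites (an assumption, not necessarily the author's code):
def check_parameter_keys(parameter, needed_keys=None, optional_keys=None, default_value=None):
    """Validate required dict keys and fill optional ones with a default."""
    if needed_keys:
        for key in needed_keys:
            if key not in parameter:
                raise ValueError('{p} is missing the required key {k}'.format(p=parameter, k=key))
    if optional_keys:
        for key in optional_keys:
            if key not in parameter:
                parameter[key] = default_value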
                    print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                    print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                    print('Loss RPN regression: {}'.format(loss_rpn_regr))
                    print('Loss Detector classifier: {}'.format(loss_class_cls))
                    print('Loss Detector regression: {}'.format(loss_class_regr))
                    print('Elapsed time: {}'.format(time.time() - start_time))

                curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                iter_num = 0
                start_time = time.time()

                if curr_loss < best_loss:
                    if C.verbose:
                        print('Total loss decreased from {} to {}, saving weights'.format(
                            best_loss, curr_loss))
                    best_loss = curr_loss
                    model_all.save_weights(C.model_path)

                break

        except Exception as e:
            print('Exception: {}'.format(e))
            continue

print('Training complete, exiting.')
tensorboard.on_train_end(None)
class ExtendedLogger(Callback):

    val_data_metrics = {}

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):

        if stateful and stateful_reset_interval is None:
            raise ValueError('If the model is stateful, then the reset interval has to be defined!')

        super(ExtendedLogger, self).__init__()

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies

        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    def set_params(self, params):
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        with timeit('metrics'):
            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input, outputs=outputs)

            batch_size = self.params['batch_size']
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]
            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            # predict() takes a list of callbacks via the `callbacks` keyword
            y_pred = self.prediction_model.predict(
                val_data[:-1],
                batch_size=batch_size,
                verbose=1,
                callbacks=[callback] if callback else None)

            print(y_true.shape, y_pred.shape)
            self.write_prediction(epoch, y_true, y_pred)

            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            new_logs = {
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            }
            logs.update(new_logs)

            homo_logs = self.try_add_homoscedastic_params()
            logs.update(homo_logs)

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')
        if homo_pos_loss_layer:
            homo_pos_log_vars = np.array(homo_pos_loss_layer.get_weights()[0])
            homo_quat_log_vars = np.array(homo_quat_loss_layer.get_weights()[0])
            return {
                'pos_log_var': np.array(homo_pos_log_vars),
                'quat_log_var': np.array(homo_quat_log_vars),
            }
        else:
            return {}

    def write_prediction(self, epoch, y_true, y_pred):
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self, epoch, y_true, y_pred, max_segment=1000):
        if self.starting_indicies is None:
            # list(...) is required in Python 3, where range() is not a list
            self.starting_indicies = {'valid': list(range(0, 4000, 1000)) + [4000]}

        for begin, end in pairwise(self.starting_indicies['valid']):
            diff = end - begin
            if diff > max_segment:
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)
            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        true_q = quaternion.as_quat_array(y_true[begin:end, [6, 3, 4, 5]])
        true_q = quaternion.as_euler_angles(true_q)[1]
        pred_q = quaternion.as_quat_array(y_pred[begin:end, [6, 3, 4, 5]])
        pred_q = quaternion.as_euler_angles(pred_q)[1]

        plt.clf()
        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k', linestyle='-', linewidth=0.3, alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)
        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
            .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()
        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
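# The pairwise helper used by save_topview_trajectories above is not shown in
# this snippet; the standard itertools recipe (an assumption about the
# author's helper) yields consecutive overlapping pairs:
from itertools import tee

def pairwise(iterable):
    """s -> (s0, s1), (s1, s2), (s2, s3), ..."""
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

print(list(pairwise([0, 1000, 2000, 3000, 4000])))
# [(0, 1000), (1000, 2000), (2000, 3000), (3000, 4000)]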
    def train(self, epochs, batch_size=BATCH_SIZE, sample_interval=50):
        tensorboard = TensorBoard(log_dir=LOG_DIR)
        tensorboard.set_model(self.discriminator)

        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Detect batch size in npys
            batch_size = min(self.this_npy_num_imgs, batch_size)
            # np.random.randint's upper bound is exclusive, so pass the full count
            idx = np.random.randint(0, self.this_npy_num_imgs, batch_size)

            # Select a random batch of images
            # self.X_train = os.path.join(OUTPATH, np.random.choice(os.listdir(OUTPATH)))
            self.X_train = np.load(train_robin, allow_pickle=True)
            self.X_train = self.X_train[idx]
            self.X_train = np.expand_dims(self.X_train, axis=3)
            self.X_train = self.X_train / (255 / 2) - 1

            noise = np.random.normal(-1, 1, ((batch_size,) + self.latent_dim))

            # Adversarial ground truths (shaped like the image batch here;
            # these must match the discriminator's output shape)
            valid = np.ones(self.X_train.shape)
            fake = np.zeros(self.X_train.shape)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            if epoch == 0 or accuracy < 80:
                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(self.X_train, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            else:
                # Test the discriminator
                d_loss_real = self.discriminator.test_on_batch(self.X_train, valid)
                d_loss_fake = self.discriminator.test_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            accuracy = 100 * d_loss[1]

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(-1, 1, ((batch_size,) + self.latent_dim))

            if epoch == 0 or accuracy > 20:
                # Train the generator (to have the discriminator label samples as valid)
                g_loss = self.combined.train_on_batch(noise, valid)
            else:
                # Only evaluate the generator while the discriminator catches up
                g_loss = self.combined.test_on_batch(noise, valid)

            tensorboard.on_epoch_end(epoch, {'generator loss': g_loss,
                                             'discriminator loss': d_loss[0],
                                             'Accuracy': accuracy,
                                             'Comb. loss': g_loss + d_loss[0]})

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                print(f"@ {epoch:{len(str(EPOCHS))}}:\t"
                      f"Accuracy: {int(accuracy):3}%\t"
                      f"G-Loss: {g_loss:6.3f}\t"
                      f"D-Loss: {d_loss[0]:6.3f}\t"
                      f"Combined: {g_loss+d_loss[0]:6.3f}")
                self.sample_images(epoch)

        # on_train_end expects a logs dict (or None), not the callback itself
        tensorboard.on_train_end(None)
        self.discriminator.save('discriminator.h5')
        self.generator.save('generator.h5')
    def train(self, epochs, batch_size=BATCH_SIZE, sample_interval=50):
        tensorboard = TensorBoard(log_dir=LOG_DIR)
        tensorboard.set_model(self.discriminator)

        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Detect batch size in npys
            batch_size = min(self.img_per_npy, batch_size)

            # Select a random batch of images
            self.X_train = os.path.join(OUTPATH, np.random.choice(os.listdir(OUTPATH)))
            self.X_train = np.load(self.X_train, allow_pickle=True)
            idx = np.random.randint(0, len(self.X_train), batch_size)
            self.X_train = self.X_train[idx]
            self.X_train = np.expand_dims(self.X_train, axis=3)
            # Maps pixel values from [0, 255] to [1, -1] (inverted scale)
            self.X_train = self.X_train / (-255 / 2) + 1

            noise = np.random.normal(-1, 1, ((batch_size,) + LATENT_SIZE))

            # Adversarial ground truths, optionally smoothed and randomly flipped
            valid = np.ones((batch_size,))
            if ADD_LABEL_NOISE:
                valid -= np.random.uniform(high=LABEL_NOISE, size=(batch_size,))
            for img in range(batch_size):
                if np.random.rand() < P_FLIP_LABEL:
                    valid[img] = 1 - valid[img]

            fake = np.zeros((batch_size,))
            if ADD_LABEL_NOISE:
                fake += np.random.uniform(high=LABEL_NOISE, size=(batch_size,))
            for img in range(batch_size):
                if np.random.rand() < P_FLIP_LABEL:
                    fake[img] = 1 - fake[img]

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            if epoch == 0 or accuracy < 80:
                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(self.X_train, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            else:
                # Test the discriminator
                d_loss_real = self.discriminator.test_on_batch(self.X_train, valid)
                d_loss_fake = self.discriminator.test_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            accuracy = 100 * d_loss[1]

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(-1, 1, ((batch_size,) + LATENT_SIZE))

            if epoch == 0 or accuracy > 52:
                # Train the generator (to have the discriminator label samples as valid)
                g_loss = self.combined.train_on_batch(noise, valid)
            else:
                # Only evaluate the generator while the discriminator catches up
                g_loss = self.combined.test_on_batch(noise, valid)

            tensorboard.on_epoch_end(epoch, {'generator loss': g_loss,
                                             'discriminator loss': d_loss[0],
                                             'Accuracy': accuracy,
                                             'Comb. loss': g_loss + d_loss[0]})

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                print(f"@ {epoch:{len(str(EPOCHS))}}:\t"
                      f"Accuracy: {int(accuracy):3}%\t"
                      f"G-Loss: {g_loss:6.3f}\t"
                      f"D-Loss: {d_loss[0]:6.3f}\t"
                      f"Combined: {g_loss+d_loss[0]:6.3f}")
                self.sample_images(epoch, accuracy, real_imgs=self.X_train)

            if epoch % SAVE_INTERVAL == 0:
                self.discriminator.save('discriminator.h5')
                self.generator.save('generator.h5')

        # on_train_end expects a logs dict (or None), not the callback itself
        tensorboard.on_train_end(None)
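# The label smoothing and random label flipping above appear twice (once for
# 'valid', once for 'fake'); a small helper could factor them out. A sketch
# with hypothetical names, not part of the original code:
import numpy as np

def noisy_labels(value, batch_size, noise=0.05, p_flip=0.02):
    """Return smoothed, randomly flipped adversarial targets for one class."""
    labels = np.full((batch_size,), float(value))
    # one-sided smoothing: nudge labels toward the opposite class
    labels += np.random.uniform(high=noise, size=(batch_size,)) * (1 if value == 0 else -1)
    # flip a small fraction of labels to keep the discriminator honest
    flip = np.random.rand(batch_size) < p_flip
    labels[flip] = 1 - labels[flip]
    return labels

valid = noisy_labels(1, 32)
fake = noisy_labels(0, 32)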
class Trainer:
    """Class object to set up and carry out the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
    to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GANS fashion if a discriminator is
    provided, otherwise carries out a regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep
            features component of the perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: float.
        loss_weights: dictionary, used to weigh the components of the loss function.
            Contains 'MSE' for the MSE loss component, and can contain 'discriminator'
            and 'feat_extr' for the discriminator and deep features components respectively.
        logs_dir: path to the directory where the tensorboard logs are saved.
        weights_dir: path to the directory where the weights are saved.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation: integer, number of validation samples used at training from the validation set.
        T: 0 < float < 1, determines the 'flatness' threshold level for the training
            patches. See the TrainerHelper class for more details.
        lr_decay_frequency: integer, every how many epochs the learning rate is reduced.
        lr_decay_factor: 0 < float < 1, learning rate reduction multiplicative factor.

    Methods:
        train: combines the networks and triggers training with the specified settings.
""" def __init__( self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, learning_rate=0.0004, loss_weights={'MSE': 1.0}, logs_dir='logs', weights_dir='weights', dataname=None, weights_generator=None, weights_discriminator=None, n_validation=None, T=0.01, lr_decay_frequency=100, lr_decay_factor=0.5, ): if discriminator: assert generator.patch_size * generator.scale == discriminator.patch_size if feature_extractor: assert generator.patch_size * generator.scale == feature_extractor.patch_size self.generator = generator self.discriminator = discriminator self.feature_extractor = feature_extractor self.scale = generator.scale self.lr_patch_size = generator.patch_size self.learning_rate = learning_rate self.loss_weights = loss_weights self.best_metrics = {} self.pretrained_weights_path = { 'generator': weights_generator, 'discriminator': weights_discriminator, } self.lr_decay_factor = lr_decay_factor self.lr_decay_frequency = lr_decay_frequency self.helper = TrainerHelper( generator=self.generator, weights_dir=weights_dir, logs_dir=logs_dir, lr_train_dir=lr_train_dir, feature_extractor=self.feature_extractor, discriminator=self.discriminator, dataname=dataname, pretrained_weights_path=self.pretrained_weights_path, ) self.model = self._combine_networks() self.train_dh = DataHandler( lr_dir=lr_train_dir, hr_dir=hr_train_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=None, T=T, ) self.valid_dh = DataHandler( lr_dir=lr_valid_dir, hr_dir=hr_valid_dir, patch_size=self.lr_patch_size, scale=self.scale, n_validation_samples=n_validation, T=0.01, ) self.logger = get_logger(__name__) def _combine_networks(self): """ Constructs the combined model which contains the generator network, as well as discriminator and geature extractor, if any are defined. """ lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, )) sr = self.generator.model(lr) outputs = [sr] losses = ['mse'] loss_weights = [self.loss_weights['MSE']] if self.discriminator: self.discriminator.model.trainable = False validity = self.discriminator.model(sr) outputs.append(validity) losses.append('binary_crossentropy') loss_weights.append(self.loss_weights['discriminator']) if self.feature_extractor: self.feature_extractor.model.trainable = False sr_feats = self.feature_extractor.model(sr) outputs.extend([*sr_feats]) losses.extend(['mse'] * len(sr_feats)) loss_weights.extend( [self.loss_weights['feat_extr'] / len(sr_feats)] * len(sr_feats)) combined = Model(inputs=lr, outputs=outputs) # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows optimizer = Adam(epsilon=0.0000001) combined.compile(loss=losses, loss_weights=loss_weights, optimizer=optimizer, metrics={'generator': PSNR}) return combined def _lr_scheduler(self, epoch): """ Scheduler for the learning rate updates. """ n_decays = epoch // self.lr_decay_frequency # no lr below minimum control 10e-6 return max(1e-6, self.learning_rate * (self.lr_decay_factor**n_decays)) def _load_weights(self): """ Loads the pretrained weights from the given path, if any is provided. If a discriminator is defined, does the same. 
""" gen_w = self.pretrained_weights_path['generator'] if gen_w: self.model.get_layer('generator').load_weights(gen_w) if self.discriminator: dis_w = self.pretrained_weights_path['discriminator'] if dis_w: self.model.get_layer('discriminator').load_weights(dis_w) self.discriminator.model.load_weights(dis_w) def train(self, epochs, steps_per_epoch, batch_size): """ Carries on the training for the given number of epochs. Sends the losses to Tensorboard. """ starting_epoch = self.helper.initialize_training( self) # load_weights, creates folders, creates basename self.tensorboard = TensorBoard( log_dir=self.helper.callback_paths['logs']) self.tensorboard.set_model(self.model) # validation data validation_set = self.valid_dh.get_validation_set(batch_size) y_validation = [validation_set['hr']] if self.discriminator: discr_out_shape = list( self.discriminator.model.outputs[0].shape)[1:4] valid = np.ones([batch_size] + discr_out_shape) fake = np.zeros([batch_size] + discr_out_shape) validation_valid = np.ones([len(validation_set['hr'])] + discr_out_shape) y_validation.append(validation_valid) if self.feature_extractor: validation_feats = self.feature_extractor.model.predict( validation_set['hr']) y_validation.extend([*validation_feats]) for epoch in range(starting_epoch, epochs): self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch, tot_eps=epochs)) K.set_value(self.model.optimizer.lr, self._lr_scheduler(epoch=epoch)) self.logger.info('Current learning rate: {}'.format( K.eval(self.model.optimizer.lr))) epoch_start = time() for step in tqdm(range(steps_per_epoch)): batch = self.train_dh.get_batch(batch_size) sr = self.generator.model.predict(batch['lr']) y_train = [batch['hr']] losses = {} ## Discriminator training if self.discriminator: d_loss_real = self.discriminator.model.train_on_batch( batch['hr'], valid) d_loss_fake = self.discriminator.model.train_on_batch( sr, fake) d_loss_real = dict( zip( [ 'train_d_real_' + m for m in self.discriminator.model.metrics_names ], d_loss_real, )) d_loss_fake = dict( zip( [ 'train_d_fake_' + m for m in self.discriminator.model.metrics_names ], d_loss_fake, )) losses.update(d_loss_real) losses.update(d_loss_fake) y_train.append(valid) ## Generator training if self.feature_extractor: hr_feats = self.feature_extractor.model.predict( batch['hr']) y_train.extend([*hr_feats]) trainig_loss = self.model.train_on_batch(batch['lr'], y_train) losses.update( dict( zip(['train_' + m for m in self.model.metrics_names], trainig_loss))) self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step, losses) self.logger.debug('Losses at step {s}:\n {l}'.format(s=step, l=losses)) elapsed_time = time() - epoch_start self.logger.info('Epoch {} took {:10.1f}s'.format( epoch, elapsed_time)) validation_loss = self.model.evaluate(validation_set['lr'], y_validation, batch_size=batch_size) losses = dict( zip(['val_' + m for m in self.model.metrics_names], validation_loss)) monitored_metrics = {} if (not self.discriminator) and (not self.feature_extractor): monitored_metrics.update({'val_loss': 'min'}) else: monitored_metrics.update({'val_generator_loss': 'min'}) self.helper.on_epoch_end( epoch=epoch, losses=losses, generator=self.model.get_layer('generator'), discriminator=self.discriminator, metrics=monitored_metrics, ) self.tensorboard.on_epoch_end(epoch, losses) self.tensorboard.on_train_end(None)
        for i in range(len(x)):
            # train disc on real
            disc.train_on_batch([x[i], y[i]], real_y)

            # gen fake
            fake = gen.predict(x[i])

            # train disc on fake
            disc.train_on_batch([x[i], fake], fake_y)

            # train combined
            disc.trainable = False
            combined.train_on_batch(x[i], [y[i], real_y])
            disc.trainable = True

            # log.write(str(e) + ", " + str(s) + ", " + str(dr_loss) + ", " + str(df_loss) + ", "
            #           + str(g_loss[0]) + ", " + str(g_loss[1]) + ", "
            #           + str(opt_dcgan.get_config()["lr"]) + "\n")

        # output random result
        # val_sequence = sequences[train_offset:]
        # generated_y = gen.predict(x[random_index])
        # save_image(strip(x[random_index]) / 2 + 0.5, y[random_index], re_shape(generated_y),
        #            "validation/e{}_{}.png".format(e, s))

        # save weights
        gen.save_weights(checkpoint_gen_name, overwrite=True)
        disc.save_weights(checkpoint_disc_name, overwrite=True)
        tensorlog.on_epoch_end(e)

    tensorlog.on_train_end()
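# In Keras, a model's trainable flags are captured when it is compiled, so
# toggling disc.trainable between train_on_batch calls (as above) only works
# if the models were compiled in the right states. A minimal runnable sketch
# of the usual pattern (toy shapes, assuming tf.keras):
import tensorflow as tf
from tensorflow.keras import layers, Model

z = layers.Input((8,))
gen = Model(z, layers.Dense(4, activation='tanh')(z))

d_in = layers.Input((4,))
disc = Model(d_in, layers.Dense(1, activation='sigmoid')(d_in))
disc.compile(optimizer='adam', loss='binary_crossentropy')  # trainable captured here

disc.trainable = False                     # frozen only inside `combined`
combined = Model(z, disc(gen(z)))
combined.compile(optimizer='adam', loss='binary_crossentropy')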
    def train(self, epochs, batch_size=128, sample_interval=100):
        (X_train, _), (_, _) = mnist.load_data()

        # Normalization to the scale -1 to 1
        # X_train = X_train.astype('float32')
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Create the labels
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        tensorboard = TensorBoard(log_dir='./tmp/logs',
                                  histogram_freq=0,
                                  write_graph=True)
        tensorboard.set_model(self.combined)

        g_loss_list = []
        d_loss_list = []

        for epoch in range(epochs):
            # ----------------------------- #
            #  Randomly pick batch images
            #  to train the discriminator
            # ----------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Generate batch-size of random noise with latent dimension size
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # The loss of the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # Avg loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            d_loss_list.append(d_loss[0])

            # --------------------------- #
            #     Train the generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            g_loss_list.append(g_loss)

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            tensorboard.on_epoch_end(epoch, self._named_logs(['d_loss', 'd_accuracy'], d_loss))
            tensorboard.on_epoch_end(epoch, self._named_logs(['g_loss'], [g_loss]))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

        tensorboard.on_train_end(None)

        # Save the loss image
        plt.xlabel("Steps")
        plt.ylabel("Loss")
        plt.title("Loss of Generator and Discriminator")
        plt.plot(np.arange(epochs), g_loss_list, label="Generator Loss")
        plt.plot(np.arange(epochs), d_loss_list, label="Discriminator Loss")
        plt.legend()
        plt.savefig('epoch {}.png'.format(epochs))
        print("Training finished.\n...\n")

        # Model Saving
        self.generator.save("models/generator.tf")
        self.discriminator.save("models/discriminator.tf")
        self.combined.save("models/gan.tf")
        print("Model Saved.\n-------------------------------------------------------------------")
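# The _named_logs helper referenced above is not shown in this snippet; a
# plausible minimal implementation (an assumption, not the author's code)
# zips metric names with values into the dict TensorBoard.on_epoch_end expects:
def _named_logs(self, names, values):
    return {name: float(value) for name, value in zip(names, values)}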
class Network:
    """
    A deep convolutional neural network that takes as input the ML representation of a game
    (n, m, k) and tries to return 1 if player 1 (white in chess, red in checkers etc.) is
    going to win, -1 if player 2 is going to win, and 0 if the game is going to be a draw.
    The Network's model will input a binary matrix with a shape of GameClass.STATE_SHAPE and
    output a tuple consisting of the probability distribution over legal moves and the
    position's evaluation.
    """

    def __init__(self, GameClass, model_path=None, reinforcement_training=False, hyper_params=None):
        self.GameClass = GameClass
        self.model_path = model_path
        # lazily initialized so Network can be passed between processes before being initialized
        self.model = None
        self.reinforcement_training = reinforcement_training
        self.hyper_params = hyper_params if hyper_params is not None else {}
        self.tensor_board = None
        self.epoch = 0

    def initialize(self):
        """
        Initializes the Network's model.
        """
        # Note: keras imports are within functions to prevent initializing keras
        # in processes that import from this file
        from keras.models import load_model
        from keras.callbacks import TensorBoard

        if self.model is not None:
            return
        if self.model_path is not None:
            input_shape = self.GameClass.STATE_SHAPE
            output_shape = self.GameClass.MOVE_SHAPE
            self.model = load_model(self.model_path)
            if self.model.input_shape != (None,) + input_shape:
                raise Exception('Input shape of loaded model doesn\'t match!')
            if self.model.output_shape != [(None,) + output_shape, (None, 1)]:
                raise Exception('Output shape of loaded model doesn\'t match!')
            # TODO: recompile model with loss_weights and learning schedule from config file
        else:
            self.model = self.create_model(**self.hyper_params)

        if self.reinforcement_training:
            self.tensor_board = TensorBoard(
                log_dir=f'{get_training_path(self.GameClass)}/logs/'
                        f'model_reinforcement_{time()}',
                histogram_freq=0, write_graph=True)
            self.tensor_board.set_model(self.model)

    def create_model(self, kernel_size=(4, 4), convolutional_filters=64, residual_layers=6,
                     value_head_neurons=16, policy_loss_value=1):
        """
        https://www.youtube.com/watch?v=OPgRNY3FaxA
        """
        # Note: keras imports are within functions to prevent initializing keras
        # in processes that import from this file
        from keras.models import Model
        from keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense, Activation, Add, Reshape

        input_shape = self.GameClass.STATE_SHAPE
        output_shape = self.GameClass.MOVE_SHAPE
        output_neurons = np.product(output_shape)

        input_tensor = Input(input_shape)

        # convolutional layer
        x = Conv2D(convolutional_filters, kernel_size, padding='same')(input_tensor)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

        # residual layers
        for _ in range(residual_layers):
            y = Conv2D(convolutional_filters, kernel_size, padding='same')(x)
            y = BatchNormalization()(y)
            y = Activation('relu')(y)
            y = Conv2D(convolutional_filters, kernel_size, padding='same')(y)
            y = BatchNormalization()(y)
            # noinspection PyTypeChecker
            x = Add()([x, y])
            x = Activation('relu')(x)

        # policy head
        policy = Conv2D(2, (1, 1), padding='same')(x)
        policy = BatchNormalization()(policy)
        policy = Activation('relu')(policy)
        policy = Flatten()(policy)
        policy = Dense(output_neurons, activation='softmax')(policy)
        policy = Reshape(output_shape, name='policy')(policy)

        # value head
        value = Conv2D(1, (1, 1), padding='same')(x)
        value = BatchNormalization()(value)
        value = Activation('relu')(value)
        value = Flatten()(value)
        value = Dense(value_head_neurons, activation='relu')(value)
        value = Dense(1,
                      activation='tanh', name='value')(value)

        model = Model(input_tensor, [policy, value])
        model.compile(optimizer='adam',
                      loss={
                          'policy': 'categorical_crossentropy',
                          'value': 'mean_squared_error'
                      },
                      loss_weights={
                          'policy': policy_loss_value,
                          'value': 1
                      },
                      metrics=['mean_squared_error'])
        return model

    def predict(self, states):
        return self.model.predict(states)

    def call(self, states):
        """
        For any of the given states, if no moves are legal, then the corresponding probability
        distribution will be a list with a single 1. This is done to allow for pass moves
        which are not encapsulated by GameClass.MOVE_SHAPE.

        :param states: The input positions with shape (k,) + GameClass.STATE_SHAPE,
                       where k is the number of positions.
        :return: A list of length k. Each element of the list is a tuple where the 0th element
                 is the probability distribution on legal moves (ordered correspondingly with
                 GameClass.get_possible_moves), and the 1st element is the evaluation
                 (a float in (-1, 1)).
        """
        raw_policies, evaluations = self.predict(states)

        filtered_policies = [
            raw_policy[self.GameClass.get_legal_moves(state)]
            for state, raw_policy in zip(states, raw_policies)
        ]
        filtered_policies = [
            filtered_policy / np.sum(filtered_policy) if len(filtered_policy) > 0 else [1]
            for filtered_policy in filtered_policies
        ]
        evaluations = evaluations.reshape(states.shape[0])

        return [(filtered_policy, evaluation)
                for filtered_policy, evaluation in zip(filtered_policies, evaluations)]

    def choose_move(self, position, return_distribution=False, optimal=False):
        distribution, evaluation = self.call(position[np.newaxis, ...])[0]
        # the optimal move is the one with the highest probability
        idx = np.argmax(distribution) if optimal else np.random.choice(
            np.arange(len(distribution)), p=distribution)
        move = self.GameClass.get_possible_moves(position)[idx]
        return (move, distribution) if return_distribution else move

    def train(self, data, validation_fraction=0.2):
        # Note: keras imports are within functions to prevent initializing keras
        # in processes that import from this file
        from keras.callbacks import TensorBoard, EarlyStopping

        split = int((1 - validation_fraction) * len(data))
        train_input, train_output = self.get_training_data(self.GameClass, data[:split])
        test_input, test_output = self.get_training_data(self.GameClass, data[split:])
        print('Training Samples:', train_input.shape[0])
        print('Validation Samples:', test_input.shape[0])

        self.model.fit(
            train_input, train_output,
            epochs=100,
            validation_data=(test_input, test_output),
            callbacks=[
                TensorBoard(log_dir=f'{get_training_path(self.GameClass)}/logs/model_{time()}'),
                EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
            ])

    def train_step(self, states, policies, values):
        logs = self.model.train_on_batch(states, [policies, values], return_dict=True)
        self.tensor_board.on_epoch_end(self.epoch, logs)
        self.epoch += 1

    def finish_training(self):
        self.tensor_board.on_train_end()

    def save(self, model_path):
        self.model.save(model_path)

    def equal_model_architecture(self, network):
        """
        Both networks must be initialized.

        :return: True if this Network's model and the given network's model have
                 the same architecture.
        """
        return self.model.get_config() == network.model.get_config()

    @classmethod
    def get_training_data(cls, GameClass, data, one_hot=False, shuffle=True):
        """
        :param GameClass:
        :param data: A list of game, outcome tuples. Each game is a list of position,
                     distribution tuples.
        :param one_hot:
        :param shuffle:
        :return:
        """
        states = []
        policy_outputs = []
        value_outputs = []
        for game, outcome in data:
            for position, distribution in game:
                legal_moves = GameClass.get_legal_moves(position)
                policy = np.zeros_like(legal_moves, dtype=float)
                policy[legal_moves] = distribution
                policy /= np.sum(policy)  # rescale so total probability is 1

                if one_hot:
                    idx = np.unravel_index(policy.argmax(), policy.shape)
                    policy = np.zeros_like(policy)
                    policy[idx] = 1

                states.append(position)
                policy_outputs.append(policy)
                value_outputs.append(outcome)

        input_data = np.stack(states, axis=0)
        policy_outputs = np.stack(policy_outputs, axis=0)
        value_outputs = np.array(value_outputs)

        if shuffle:
            shuffle_indices = np.arange(input_data.shape[0])
            np.random.shuffle(shuffle_indices)
            input_data = input_data[shuffle_indices, ...]
            policy_outputs = policy_outputs[shuffle_indices, ...]
            value_outputs = value_outputs[shuffle_indices]

        return input_data, [policy_outputs, value_outputs]
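# Tiny self-contained demo of the policy-target construction performed in
# get_training_data above: a distribution over legal moves is scattered into
# the full move grid, then renormalized. Toy values for illustration:
import numpy as np

legal_moves = np.array([[True, False], [False, True]])  # toy 2x2 move grid
distribution = np.array([0.25, 0.75])                   # over legal moves only

policy = np.zeros_like(legal_moves, dtype=float)
policy[legal_moves] = distribution
policy /= np.sum(policy)
print(policy)  # [[0.25 0.  ]
               #  [0.   0.75]]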
    def train(self, batch_size=4, epochs=25):
        cf = self.cf
        self.compile()
        model = self.keras_model
        word_vectors, char_vectors, train_ques_ids, X_train, y_train, val_ques_ids, X_valid, y_valid = self.data_train

        qanet_cb = QANetCallback(decay=cf.EMA_DECAY)
        tb = TensorBoard(log_dir=cf.TENSORBOARD_PATH,
                         histogram_freq=0,
                         write_graph=False,
                         write_images=False,
                         update_freq=cf.TENSORBOARD_UPDATE_FREQ)

        # Call set_model for all callbacks
        qanet_cb.set_model(model)
        tb.set_model(model)

        ep_list = []
        avg_train_loss_list = []
        em_score_list = []
        f1_score_list = []
        global_steps = 0
        gt_start_list, gt_end_list = y_valid[2:]

        for ep in range(1, epochs + 1):  # Epoch numbers start from 1
            print('----------- Training for epoch {}...'.format(ep))

            # Train
            batch = 0
            sum_loss = 0
            num_batches = (len(X_train[0]) - 1) // batch_size + 1
            for X_batch, y_batch in get_batch(X_train, y_train,
                                              batch_size=batch_size, shuffle=True):
                batch_logs = {'batch': batch, 'size': len(X_batch[0])}
                tb.on_batch_begin(batch, batch_logs)

                loss, loss_p1, loss_p2, loss_start, loss_end = model.train_on_batch(X_batch, y_batch)
                sum_loss += loss
                avg_loss = sum_loss / (batch + 1)
                print('Epoch: {}/{}, Batch: {}/{}, Accumulative average loss: {:.4f}, '
                      'Loss: {:.4f}, Loss_P1: {:.4f}, Loss_P2: {:.4f}, '
                      'Loss_start: {:.4f}, Loss_end: {:.4f}'.format(
                          ep, epochs, batch, num_batches, avg_loss, loss,
                          loss_p1, loss_p2, loss_start, loss_end))

                batch_logs.update({'loss': loss, 'loss_p1': loss_p1, 'loss_p2': loss_p2})
                qanet_cb.on_batch_end(batch, batch_logs)
                tb.on_batch_end(batch, batch_logs)

                global_steps += 1
                batch += 1

            ep_list.append(ep)
            avg_train_loss_list.append(avg_loss)

            print('Backing up temp weights...')
            model.save_weights(cf.TEMP_MODEL_PATH)

            qanet_cb.on_epoch_end(ep)  # Apply EMA weights
            model.save_weights(cf.MODEL_PATH % str(ep))

            print('----------- Validating for epoch {}...'.format(ep))
            valid_scores = self.validate(X_valid, y_valid, gt_start_list, gt_end_list,
                                         batch_size=cf.BATCH_SIZE)
            em_score_list.append(valid_scores['exact_match'])
            f1_score_list.append(valid_scores['f1'])
            print('------- Result of epoch: {}/{}, Average_train_loss: {:.6f}, '
                  'EM: {:.4f}, F1: {:.4f}\n'.format(
                      ep, epochs, avg_loss, valid_scores['exact_match'], valid_scores['f1']))

            tb.on_epoch_end(ep, {'f1': valid_scores['f1'],
                                 'em': valid_scores['exact_match']})

            # Write result to CSV file
            result = pd.DataFrame({
                'epoch': ep_list,
                'avg_train_loss': avg_train_loss_list,
                'em': em_score_list,
                'f1': f1_score_list
            })
            result.to_csv(cf.RESULT_LOG, index=None)

            # Restore the original weights to continue training
            print('Restoring temp weights...')
            model.load_weights(cf.TEMP_MODEL_PATH)

        tb.on_train_end(None)
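# The save/on_epoch_end/save/load dance above implements the usual EMA
# evaluate-then-restore pattern: validate and checkpoint with shadow
# (exponential-moving-average) weights, then restore the raw weights to keep
# training. A plain-numpy sketch with illustrative names, not QANetCallback's API:
import numpy as np

def ema_update(ema_weights, weights, decay=0.999):
    return [decay * e + (1 - decay) * w for e, w in zip(ema_weights, weights)]

weights = [np.ones(3)]                  # raw training weights
ema = [np.zeros(3)]                     # shadow copy
for _ in range(10):                     # each training step updates the shadow
    ema = ema_update(ema, weights)
backup = [w.copy() for w in weights]    # back up raw weights (TEMP_MODEL_PATH)
weights = ema                           # evaluate / checkpoint with EMA weights
weights = backup                        # restore raw weights and continue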
def train_model(model, data, config, include_tensorboard):
    model_history = History()
    model_history.on_train_begin()
    saver = ModelCheckpoint(full_path(config.model_file()),
                            verbose=1, save_best_only=True, period=1)
    saver.set_model(model)
    early_stopping = EarlyStopping(min_delta=config.min_delta,
                                   patience=config.patience, verbose=1)
    early_stopping.set_model(model)
    early_stopping.on_train_begin()
    csv_logger = CSVLogger(full_path(config.csv_log_file()))
    csv_logger.on_train_begin()
    if include_tensorboard:
        tensorboard = TensorBoard(histogram_freq=10, write_images=True)
        tensorboard.set_model(model)
    else:
        tensorboard = Callback()

    epoch = 0
    stop = False
    while epoch <= config.max_epochs and not stop:
        epoch_history = History()
        epoch_history.on_train_begin()
        valid_sizes = []
        train_sizes = []
        print("Epoch:", epoch)
        for dataset in data.datasets:
            print("dataset:", dataset.name)
            model.reset_states()
            dataset.reset_generators()

            valid_sizes.append(dataset.valid_generators[0].size())
            train_sizes.append(dataset.train_generators[0].size())
            fit_history = model.fit_generator(dataset.train_generators[0],
                                              dataset.train_generators[0].size(),
                                              nb_epoch=1,
                                              verbose=0,
                                              validation_data=dataset.valid_generators[0],
                                              nb_val_samples=dataset.valid_generators[0].size())
            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

            train_sizes.append(dataset.train_generators[1].size())
            fit_history = model.fit_generator(dataset.train_generators[1],
                                              dataset.train_generators[1].size(),
                                              nb_epoch=1,
                                              verbose=0)
            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

        epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
        model_history.on_epoch_end(epoch, logs=epoch_logs)
        saver.on_epoch_end(epoch, logs=epoch_logs)
        early_stopping.on_epoch_end(epoch, epoch_logs)
        csv_logger.on_epoch_end(epoch, epoch_logs)
        tensorboard.on_epoch_end(epoch, epoch_logs)
        epoch += 1
        if early_stopping.stopped_epoch > 0:
            stop = True

    early_stopping.on_train_end()
    csv_logger.on_train_end()
    tensorboard.on_train_end({})
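# Note the null-object trick above: when TensorBoard logging is disabled, a
# bare Callback() instance stands in for it, so the epoch loop can call
# on_epoch_end()/on_train_end() unconditionally instead of guarding every call.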