def build_callbacks(self, save_dir, logger, **kwargs):
    metrics = kwargs.get('metrics', 'accuracy')
    if isinstance(metrics, (list, tuple)):
        metrics = metrics[-1]
    monitor = f'val_{metrics}'
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(save_dir, 'model.h5'),
        # verbose=1,
        monitor=monitor,
        save_best_only=True,
        mode='max',
        save_weights_only=True)
    logger.debug(f'Monitor {checkpoint.monitor} for checkpoint')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=io_util.makedirs(io_util.path_join(save_dir, 'logs')))
    csv_logger = FineCSVLogger(os.path.join(save_dir, 'train.log'), separator=' | ', append=True)
    callbacks = [checkpoint, tensorboard_callback, csv_logger]
    lr_decay_per_epoch = self.config.get('lr_decay_per_epoch', None)
    if lr_decay_per_epoch:
        learning_rate = self.model.optimizer.get_config().get('learning_rate', None)
        if not learning_rate:
            logger.warning('Learning rate decay not supported for optimizer={}'.format(repr(self.model.optimizer)))
        else:
            logger.debug(f'Created LearningRateScheduler with lr_decay_per_epoch={lr_decay_per_epoch}')
            callbacks.append(tf.keras.callbacks.LearningRateScheduler(
                lambda epoch: learning_rate / (1 + lr_decay_per_epoch * epoch)))
    anneal_factor = self.config.get('anneal_factor', None)
    if anneal_factor:
        callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(factor=anneal_factor,
                                                              patience=self.config.get('anneal_patience', 10)))
    early_stopping_patience = self.config.get('early_stopping_patience', None)
    if early_stopping_patience:
        callbacks.append(tf.keras.callbacks.EarlyStopping(monitor=monitor, mode='max', verbose=1,
                                                          patience=early_stopping_patience))
    return callbacks
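# For reference, the LearningRateScheduler above applies inverse-time decay:
# lr(epoch) = lr0 / (1 + lr_decay_per_epoch * epoch). A minimal sketch to
# sanity-check the schedule in isolation; lr0=0.01 and decay=0.05 are
# hypothetical values, not taken from any config above:
def inverse_time_decay_demo(lr0=0.01, lr_decay_per_epoch=0.05, epochs=5):
    for epoch in range(epochs):
        print(f'epoch {epoch}: lr = {lr0 / (1 + lr_decay_per_epoch * epoch):.6f}')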
def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
        **kwargs):
    self._capture_config(locals())
    self.transform = self.build_transform(**self.config)
    if not save_dir:
        save_dir = tempdir_human()
    if not logger:
        logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
    logger.info('Hyperparameter:\n' + self.config.to_json())
    num_examples = self.build_vocab(trn_data, logger)
    # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
    logger.info('Building...')
    train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
    self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
    model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
    logger.info('Model built:\n' + summary_of_model(self.model))
    self.save_config(save_dir)
    self.save_vocabs(save_dir)
    self.save_meta(save_dir)
    trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
    dev_data = self.build_valid_dataset(dev_data, batch_size)
    callbacks = self.build_callbacks(save_dir, logger, **self.config)
    # need to know #batches, otherwise progbar crashes
    dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
    checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
    timer = Timer()
    try:
        history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                               num_examples=num_examples,
                                               train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                               callbacks=callbacks, logger=logger, model=model,
                                               optimizer=optimizer, loss=loss, metrics=metrics, overwrite=True))
    except KeyboardInterrupt:
        print()
        if not checkpoint or checkpoint.best in (np.inf, -np.inf):
            self.save_weights(save_dir)
            logger.info('Aborted with model saved')
        else:
            logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
        # noinspection PyTypeChecker
        history: tf.keras.callbacks.History = get_callback_by_class(callbacks, tf.keras.callbacks.History)
    delta_time = timer.stop()
    best_epoch_ago = 0
    if history and hasattr(history, 'epoch'):
        trained_epoch = len(history.epoch)
        logger.info('Trained {} epochs in {}, each epoch takes {}'.
                    format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
        io_util.save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
        monitor_history: List = history.history.get(checkpoint.monitor, None)
        if monitor_history:
            best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
        if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
            logger.info(f'Restored the best model saved with best '
                        f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                        f'saved {best_epoch_ago} epochs ago')
            self.load_weights(save_dir)  # restore best model
    return history
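# get_callback_by_class is a project helper whose implementation is not shown
# here. A minimal sketch of the behavior fit() relies on, assuming it simply
# returns the first callback of the requested class, or None if absent:
def get_callback_by_class_sketch(callbacks, cls):
    for callback in callbacks:
        if isinstance(callback, cls):
            return callback
    return None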
def main():
    batch_size = 512
    epochs = 2000
    # Hard-coded 80/20 split; cast to int because // on floats yields a float,
    # which Keras rejects for steps_per_epoch
    train_steps_per_epoch = int(all_data_len * 0.8 // batch_size)
    dev_steps = int(all_data_len * 0.2 // batch_size)
    transform = TSVTaggingTransform()
    # Load the vocabularies
    load_vocabs(transform, save_dir)
    # Build the model
    model = build(transform)
    # Convert the training and validation data to tf.data.Dataset
    trn_data = transform.file_to_dataset(trn_path, batch_size=batch_size, shuffle=True, repeat=-1)
    dev_data = transform.file_to_dataset(dev_path, batch_size=batch_size, shuffle=True, repeat=-1)
    # tf.print('Count dataset size...')
    # train_steps_per_epoch = math.ceil(size_of_dataset(trn_data) / batch_size)
    # dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
    # tf.print(f'train_steps_per_epoch: {train_steps_per_epoch}')
    # tf.print(f'dev_steps: {dev_steps}')
    # Set up the metric to monitor and checkpoint the best model on it
    metrics = "sparse_accuracy"
    monitor = f'val_{metrics}'
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(save_model, '0122_model_demo1.h5'),
        # verbose=1,
        monitor=monitor,
        save_best_only=True,
        mode='max',
        save_weights_only=True)
    # early_stopping = tf.keras.callbacks.EarlyStopping(monitor=monitor, patience=15)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=io_util.makedirs(io_util.path_join(save_model, 'logs_0122_demo1')))
    callbacks = [checkpoint, tensorboard_callback]
    # Train the model
    # tf.debugging.set_log_device_placement(False)
    history = model.fit(trn_data,
                        epochs=epochs,
                        steps_per_epoch=train_steps_per_epoch,
                        validation_data=dev_data,
                        callbacks=callbacks,
                        validation_steps=dev_steps,
                        verbose=1)
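# size_of_dataset (commented out above) determines step counts by scanning the
# dataset, which is why the hard-coded 80/20 arithmetic is used instead. A
# minimal sketch of such a counter, assuming it counts dataset elements; note
# it would loop forever on the repeat=-1 datasets built above, so it must run
# on a non-repeating copy:
def count_elements(dataset):
    n = int(tf.data.experimental.cardinality(dataset))
    if n >= 0:  # cardinality is statically known
        return n
    return sum(1 for _ in dataset)  # otherwise fall back to a full pass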
def main():
    batch_size = 512
    epochs = 2000
    transform = TSVTaggingTransform()
    # Load the vocabularies
    load_vocabs(transform, save_dir)
    # Build the model
    # strategy = tf.distribute.MirroredStrategy(devices=['/device:GPU:0', '/device:GPU:1'])
    # with strategy.scope():
    #     model = build(transform)
    model = build(transform)
    # Convert the training and validation data to tf.data.Dataset
    trn_data = transform.file_to_dataset(trn_path, batch_size=batch_size, shuffle=False, repeat=1)
    dev_data = transform.file_to_dataset(dev_path, batch_size=batch_size, shuffle=True, repeat=1)
    # Set up the metric to monitor and checkpoint the best model on it
    # (must be lowercase to match Keras' reported metric name, val_sparse_accuracy)
    metrics = "sparse_accuracy"
    monitor = f'val_{metrics}'
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        save_model,
        # verbose=1,
        monitor=monitor,
        save_best_only=True,
        mode='max',
        save_weights_only=False)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=io_util.makedirs(io_util.path_join(save_model, 'logs')))
    earlystop = tf.keras.callbacks.EarlyStopping(monitor, mode='max', patience=50)
    csvlogger = tf.keras.callbacks.CSVLogger("0129train.csv")
    learning_rate_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)
    callbacks = [checkpoint, tensorboard_callback, earlystop, csvlogger, learning_rate_scheduler]
    # Save the untrained model architecture in SavedModel format
    model.save('./model_struct/0129_best_model/', save_format='tf')
    # Train the model
    history = model.fit(trn_data,
                        epochs=epochs,
                        validation_data=dev_data,
                        callbacks=callbacks,
                        verbose=1)
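# The scheduler passed to LearningRateScheduler above is defined elsewhere in
# the project. A minimal sketch of a compatible function using the standard
# Keras (epoch, lr) -> lr signature; the 10-epoch hold and exp(-0.1) decay
# rate are hypothetical, not the project's actual values:
def scheduler(epoch, lr):
    if epoch < 10:
        return lr  # hold the initial rate for the first 10 epochs
    return lr * math.exp(-0.1)  # then decay exponentially each epoch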