def _train_bert_multitask_keras_model(
        train_dataset: tf.data.Dataset,
        eval_dataset: tf.data.Dataset,
        model: tf.keras.Model,
        params: BaseParams,
        mirrored_strategy: tf.distribute.MirroredStrategy = None):
    """Fit an already-compiled keras model and print its summary.

    Checkpoints (weights only, best ``val_mean_acc``) and TensorBoard logs
    are written to ``params.ckpt_dir``.

    Args:
        train_dataset: training examples; repeated indefinitely and bounded
            by ``params.train_steps_per_epoch``.
        eval_dataset: validation examples passed to ``model.fit``.
        model: compiled keras model to train.
        params: provides ``ckpt_dir``, ``train_epoch`` and
            ``train_steps_per_epoch``.
        mirrored_strategy: optional distribution strategy; when given,
            ``fit`` runs inside its scope.
    """
    # can't save whole model with model subclassing api due to tf bug
    # see: https://github.com/tensorflow/tensorflow/issues/42741
    # https://github.com/tensorflow/tensorflow/issues/40366
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(params.ckpt_dir, 'model'),
        save_weights_only=True,
        monitor='val_mean_acc',
        mode='auto',
        save_best_only=True)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=params.ckpt_dir)
    # DRY fix: the original duplicated the identical fit(...) call in both
    # branches; build the arguments once and only vary the strategy scope.
    fit_kwargs = dict(
        x=train_dataset.repeat(),
        validation_data=eval_dataset,
        epochs=params.train_epoch,
        callbacks=[model_checkpoint_callback, tensorboard_callback],
        steps_per_epoch=params.train_steps_per_epoch)
    if mirrored_strategy is not None:
        with mirrored_strategy.scope():
            model.fit(**fit_kwargs)
    else:
        model.fit(**fit_kwargs)
    model.summary()
def _train_bert_multitask_keras_model(
        train_dataset: tf.data.Dataset,
        eval_dataset: tf.data.Dataset,
        model: tf.keras.Model,
        params: BaseParams,
        mirrored_strategy: tf.distribute.MirroredStrategy = None):
    """Compile and fit a keras model, saving weights every epoch.

    NOTE(review): this redefines a function of the same name declared
    earlier in the file — confirm which version is intended to survive.

    Args:
        train_dataset: training examples passed directly to ``model.fit``.
        eval_dataset: validation examples.
        model: keras model; compiled here before fitting.
        params: provides ``ckpt_dir``, ``train_epoch`` and
            ``train_steps_per_epoch``.
        mirrored_strategy: optional distribution strategy.
    """
    from contextlib import nullcontext

    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(params.ckpt_dir, 'model'),
        save_weights_only=True,
        monitor='val_acc',
        mode='auto',
        save_best_only=False)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=params.ckpt_dir)
    # BUG FIX: mirrored_strategy defaults to None, but the original
    # unconditionally called mirrored_strategy.scope(), raising
    # AttributeError for the default. Fall back to a no-op context.
    scope = (mirrored_strategy.scope()
             if mirrored_strategy is not None else nullcontext())
    with scope:
        model.compile()
        model.fit(x=train_dataset,
                  validation_data=eval_dataset,
                  epochs=params.train_epoch,
                  callbacks=[model_checkpoint_callback,
                             tensorboard_callback],
                  steps_per_epoch=params.train_steps_per_epoch)
    model.summary()
def create_keras_model(mirrored_strategy: tf.distribute.MirroredStrategy,
                       params: BaseParams,
                       mode='train',
                       inputs_to_build_model=None,
                       model=None):
    """Init model in various mode.

    train: model will be loaded from huggingface
    resume: model will be loaded from params.ckpt_dir, if params.ckpt_dir
        does not contain valid checkpoint, then load from huggingface
    transfer: model will be loaded from params.init_checkpoint, the
        corresponding path should contain checkpoints saved using
        bert-multitask-learning
    predict: model will be loaded from params.ckpt_dir except optimizers'
        states
    eval: model will be loaded from params.ckpt_dir except optimizers'
        states, model will be compiled

    Args:
        mirrored_strategy (tf.distribute.MirroredStrategy): mirrored strategy
        params (BaseParams): params
        mode (str, optional): Mode, see above explanation. Defaults to 'train'.
        inputs_to_build_model (Dict, optional): A batch of data. Defaults to None.
        model (Model, optional): Keras model. Defaults to None.

    Returns:
        model: loaded model
    """
    def _get_model_wrapper(params, mode, inputs_to_build_model, model):
        if model is None:
            model = BertMultiTask(params)
            # model.run_eagerly = True
        if mode == 'resume':
            model.compile()
            # build training graph
            # model.train_step(inputs_to_build_model)
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load ALL vars including optimizers' states
            try:
                model.load_weights(os.path.join(params.ckpt_dir, 'model'),
                                   skip_mismatch=False)
            except TFNotFoundError:
                # warn() is deprecated in the logging module; use warning()
                LOGGER.warning('Not resuming since no matching ckpt found')
        elif mode == 'transfer':
            # build graph without optimizers' states
            # calling compile again should reset optimizers' states but
            # we're playing safe here
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load weights without loading optimizers' vars
            model.load_weights(os.path.join(params.init_checkpoint, 'model'))
            # compile again
            model.compile()
        elif mode == 'predict':
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load weights without loading optimizers' vars
            model.load_weights(os.path.join(params.ckpt_dir, 'model'))
        elif mode == 'eval':
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load weights without loading optimizers' vars
            model.load_weights(os.path.join(params.ckpt_dir, 'model'))
            model.compile()
        else:
            model.compile()
        return model

    if mirrored_strategy is not None:
        with mirrored_strategy.scope():
            model = _get_model_wrapper(params, mode,
                                       inputs_to_build_model, model)
    else:
        model = _get_model_wrapper(params, mode, inputs_to_build_model, model)
    return model
def main(strategy: tf.distribute.MirroredStrategy,
         global_step: tf.Tensor,
         train_writer: tf.summary.SummaryWriter,
         eval_writer: tf.summary.SummaryWriter,
         train_batch_size: int,
         eval_batch_size: int,
         job_dir: str,
         dataset_dir: str,
         dataset_filename: str,
         num_epochs: int,
         summary_steps: int,
         log_steps: int,
         dataset_spec: DatasetSpec,
         model: tf.keras.Model,
         loss_fn: tf.keras.losses.Loss,
         optimizer: tf.keras.optimizers.Optimizer):
    """Run a distributed train/eval loop with checkpointing.

    Trains ``model`` for ``num_epochs`` epochs under ``strategy``, evaluates
    after every epoch with categorical accuracy, and saves a '-best'
    checkpoint whenever accuracy matches or beats the best seen so far.
    Checkpoints (optimizer, model, global_step, best_metric) live in
    ``job_dir`` and are restored on startup if present.
    """
    # Define metrics
    eval_metric = tf.keras.metrics.CategoricalAccuracy()
    # best_metric is a tf.Variable so it is captured by the Checkpoint below
    # and survives restarts.
    best_metric = tf.Variable(eval_metric.result())

    # Define training loop
    @distributed_run(strategy)
    def train_step(inputs):
        with tf.GradientTape() as tape:
            images, labels = inputs
            logits = model(images)
            cross_entropy = loss_fn(labels, logits)
            # Sum-then-divide by the GLOBAL batch size: each replica
            # contributes its share so the cross-replica SUM reduction below
            # yields the global mean loss.
            loss = tf.reduce_sum(cross_entropy) / train_batch_size
        gradients = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(gradients, model.variables))
        # Only emit a scalar summary every `summary_steps` steps to keep the
        # event files small.
        if global_step % summary_steps == 0:
            tf.summary.scalar('loss', loss, step=global_step)
        return loss

    @distributed_run(strategy)
    def eval_step(inputs, metric):
        images, labels = inputs
        logits = model(images)
        metric.update_state(labels, logits)

    # Build input pipeline
    train_reader = Reader(dataset_dir, dataset_filename, split=Split.Train)
    test_reader = Reader(dataset_dir, dataset_filename, split=Split.Test)
    train_dataset = train_reader.read()
    test_dataset = test_reader.read()

    @unpack_dict
    def map_fn(_id, image, label):
        # Normalize pixel values to [0, 1]; assumes uint8 images — TODO confirm.
        return tf.cast(image, tf.float32) / 255., label

    train_dataset = dataset_spec.parse(train_dataset).batch(
        train_batch_size).map(map_fn)
    test_dataset = dataset_spec.parse(test_dataset).batch(eval_batch_size).map(
        map_fn)

    #################
    # Training loop #
    #################

    # Define checkpoint
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     global_step=global_step,
                                     best_metric=best_metric)

    # Restore the model (no-op when job_dir holds no checkpoint).
    checkpoint_dir = job_dir
    checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

    # Prepare dataset for distributed run
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_dataset = strategy.experimental_distribute_dataset(test_dataset)

    # CheckpointHandler presumably saves on exit/interrupt — verify its
    # contract before relying on it.
    with CheckpointHandler(checkpoint, checkpoint_prefix):
        for epoch in range(num_epochs):
            print('---------- Epoch: {} ----------'.format(epoch + 1))

            print('Starting training for epoch: {}'.format(epoch + 1))
            with train_writer.as_default():
                for inputs in tqdm(train_dataset,
                                   initial=global_step.numpy(),
                                   desc='Training',
                                   unit=' steps'):
                    per_replica_losses = train_step(inputs)
                    # SUM across replicas: combined with the per-replica
                    # division above this is the global mean loss.
                    mean_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                                per_replica_losses, None)
                    if global_step.numpy() % log_steps == 0:
                        print('Loss: {}'.format(mean_loss.numpy()))

                    # Increment global step
                    global_step.assign_add(1)

            print('Starting evaluation for epoch: {}'.format(epoch + 1))
            with eval_writer.as_default():
                for inputs in tqdm(test_dataset, desc='Evaluating'):
                    eval_step(inputs, eval_metric)
                accuracy = eval_metric.result()
                print('Accuracy: {}'.format(accuracy.numpy()))
                tf.summary.scalar('accuracy', accuracy, step=global_step)
                # Save a separate '-best' checkpoint on ties or improvements;
                # best_metric is updated only after a successful save.
                if accuracy >= best_metric:
                    checkpoint.save(file_prefix=checkpoint_prefix + '-best')
                    print('The best model saved: {} is higher than {}'.format(
                        accuracy.numpy(), best_metric.numpy()))
                    best_metric.assign(accuracy)
                # Reset so next epoch's accuracy is not cumulative.
                eval_metric.reset_states()
def _create_keras_model(mirrored_strategy: tf.distribute.MirroredStrategy,
                        params: BaseParams):
    """Instantiate a BertMultiTask model inside the given strategy scope.

    Args:
        mirrored_strategy: distribution strategy whose scope the model's
            variables are created under.
        params: configuration passed through to BertMultiTask.

    Returns:
        The freshly constructed BertMultiTask model.
    """
    with mirrored_strategy.scope():
        return BertMultiTask(params)