def _train_bert_multitask_keras_model(
        train_dataset: tf.data.Dataset,
        eval_dataset: tf.data.Dataset,
        model: tf.keras.Model,
        params: BaseParams,
        mirrored_strategy: tf.distribute.MirroredStrategy = None):
    # can't save whole model with model subclassing api due to tf bug
    # see: https://github.com/tensorflow/tensorflow/issues/42741
    # https://github.com/tensorflow/tensorflow/issues/40366
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(params.ckpt_dir, 'model'),
        save_weights_only=True,
        monitor='val_mean_acc',
        mode='auto',
        save_best_only=True)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=params.ckpt_dir)
    if mirrored_strategy is not None:
        with mirrored_strategy.scope():
            model.fit(
                x=train_dataset.repeat(),
                validation_data=eval_dataset,
                epochs=params.train_epoch,
                callbacks=[model_checkpoint_callback, tensorboard_callback],
                steps_per_epoch=params.train_steps_per_epoch)
    else:
        model.fit(x=train_dataset.repeat(),
                  validation_data=eval_dataset,
                  epochs=params.train_epoch,
                  callbacks=[model_checkpoint_callback, tensorboard_callback],
                  steps_per_epoch=params.train_steps_per_epoch)
    model.summary()
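
# The callback above saves weights only (whole-model saving is broken for
# subclassed models, per the issues linked above), so restoring later means
# rebuilding the model before calling load_weights. A minimal sketch, reusing
# the BertMultiTask/BaseParams classes from the other examples; `one_batch`
# is a hypothetical batch of model inputs:
model = BertMultiTask(params)
# a forward pass creates the variables of the subclassed model
_ = model(one_batch, mode=tf.estimator.ModeKeys.PREDICT)
model.load_weights(os.path.join(params.ckpt_dir, 'model'))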
Example #2
def _train_bert_multitask_keras_model(
        train_dataset: tf.data.Dataset,
        eval_dataset: tf.data.Dataset,
        model: tf.keras.Model,
        params: BaseParams,
        mirrored_strategy: tf.distribute.MirroredStrategy = None):
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(params.ckpt_dir, 'model'),
        save_weights_only=True,
        monitor='val_acc',
        mode='auto',
        save_best_only=False)

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=params.ckpt_dir)

    # guard against the default of None (calling .scope() on None would fail)
    if mirrored_strategy is None:
        mirrored_strategy = tf.distribute.get_strategy()

    with mirrored_strategy.scope():
        # compile under the strategy scope so optimizer variables are mirrored
        model.compile()
        model.fit(x=train_dataset,
                  validation_data=eval_dataset,
                  epochs=params.train_epoch,
                  callbacks=[model_checkpoint_callback, tensorboard_callback],
                  steps_per_epoch=params.train_steps_per_epoch)
    model.summary()
def create_keras_model(mirrored_strategy: tf.distribute.MirroredStrategy,
                       params: BaseParams,
                       mode='train',
                       inputs_to_build_model=None,
                       model=None):
    """init model in various mode

    train: model will be loaded from huggingface
    resume: model will be loaded from params.ckpt_dir; if params.ckpt_dir does not contain a valid checkpoint, load from huggingface
    transfer: model will be loaded from params.init_checkpoint; the corresponding path should contain checkpoints saved with bert-multitask-learning
    predict: model will be loaded from params.ckpt_dir except optimizers' states
    eval: model will be loaded from params.ckpt_dir except optimizers' states, model will be compiled

    Args:
        mirrored_strategy (tf.distribute.MirroredStrategy): mirrored strategy
        params (BaseParams): params
        mode (str, optional): Mode; see the explanation above. Defaults to 'train'.
        inputs_to_build_model (Dict, optional): A batch of data. Defaults to None.
        model (Model, optional): Keras model. Defaults to None.

    Returns:
        model: loaded model
    """
    def _get_model_wrapper(params, mode, inputs_to_build_model, model):
        if model is None:
            model = BertMultiTask(params)
            # model.run_eagerly = True
        if mode == 'resume':
            model.compile()
            # build training graph
            # model.train_step(inputs_to_build_model)
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load ALL vars including optimizers' states
            try:
                model.load_weights(os.path.join(params.ckpt_dir, 'model'),
                                   skip_mismatch=False)
            except TFNotFoundError:
                LOGGER.warning('Not resuming since no matching ckpt found')
        elif mode == 'transfer':
            # build graph without optimizers' states
            # calling compile again should reset optimizers' states but we're playing safe here
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load weights without loading optimizers' vars
            model.load_weights(os.path.join(params.init_checkpoint, 'model'))
            # compile again
            model.compile()
        elif mode == 'predict':
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load weights without loading optimizers' vars
            model.load_weights(os.path.join(params.ckpt_dir, 'model'))
        elif mode == 'eval':
            _ = model(inputs_to_build_model,
                      mode=tf.estimator.ModeKeys.PREDICT)
            # load weights without loading optimizers' vars
            model.load_weights(os.path.join(params.ckpt_dir, 'model'))
            model.compile()
        else:
            model.compile()

        return model

    if mirrored_strategy is not None:
        with mirrored_strategy.scope():
            model = _get_model_wrapper(params, mode, inputs_to_build_model,
                                       model)
    else:
        model = _get_model_wrapper(params, mode, inputs_to_build_model, model)
    return model
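
# Usage sketch for create_keras_model (`params` and `first_batch` are
# hypothetical stand-ins): resume training from params.ckpt_dir, falling back
# to the huggingface weights when no matching checkpoint exists.
strategy = tf.distribute.MirroredStrategy()
model = create_keras_model(
    mirrored_strategy=strategy,
    params=params,
    mode='resume',
    inputs_to_build_model=first_batch)  # one batch to build the graph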
Example #4
def main(strategy: tf.distribute.MirroredStrategy, global_step: tf.Tensor,
         train_writer: tf.summary.SummaryWriter,
         eval_writer: tf.summary.SummaryWriter, train_batch_size: int,
         eval_batch_size: int, job_dir: str, dataset_dir: str,
         dataset_filename: str, num_epochs: int, summary_steps: int,
         log_steps: int, dataset_spec: DatasetSpec, model: tf.keras.Model,
         loss_fn: tf.keras.losses.Loss,
         optimizer: tf.keras.optimizers.Optimizer):
    # Define metrics
    eval_metric = tf.keras.metrics.CategoricalAccuracy()
    best_metric = tf.Variable(eval_metric.result())

    # Define training loop

    @distributed_run(strategy)
    def train_step(inputs):
        images, labels = inputs

        with tf.GradientTape() as tape:
            logits = model(images)
            cross_entropy = loss_fn(labels, logits)
            # scale by the global batch size so the cross-replica SUM below
            # yields the mean loss
            loss = tf.reduce_sum(cross_entropy) / train_batch_size

        # compute and apply gradients outside the tape context so the
        # backward pass itself is not recorded on the tape
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        if global_step % summary_steps == 0:
            tf.summary.scalar('loss', loss, step=global_step)

        return loss

    @distributed_run(strategy)
    def eval_step(inputs, metric):
        images, labels = inputs

        logits = model(images)

        metric.update_state(labels, logits)

    # Build input pipeline
    train_reader = Reader(dataset_dir, dataset_filename, split=Split.Train)
    test_reader = Reader(dataset_dir, dataset_filename, split=Split.Test)
    train_dataset = train_reader.read()
    test_dataset = test_reader.read()

    @unpack_dict
    def map_fn(_id, image, label):
        return tf.cast(image, tf.float32) / 255., label

    train_dataset = dataset_spec.parse(train_dataset).batch(
        train_batch_size).map(map_fn)
    test_dataset = dataset_spec.parse(test_dataset).batch(eval_batch_size).map(
        map_fn)

    #################
    # Training loop #
    #################
    # Define checkpoint
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     global_step=global_step,
                                     best_metric=best_metric)
    # Restore the model
    checkpoint_dir = job_dir
    checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

    # Prepare dataset for distributed run
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_dataset = strategy.experimental_distribute_dataset(test_dataset)

    with CheckpointHandler(checkpoint, checkpoint_prefix):
        for epoch in range(num_epochs):
            print('---------- Epoch: {} ----------'.format(epoch + 1))

            print('Starting training for epoch: {}'.format(epoch + 1))
            with train_writer.as_default():
                for inputs in tqdm(train_dataset,
                                   initial=global_step.numpy(),
                                   desc='Training',
                                   unit=' steps'):
                    per_replica_losses = train_step(inputs)
                    # each replica's loss is already divided by the global
                    # batch size, so a cross-replica SUM yields the mean loss
                    mean_loss = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                                per_replica_losses, axis=None)

                    if global_step.numpy() % log_steps == 0:
                        print('Loss: {}'.format(mean_loss.numpy()))

                    # Increment global step
                    global_step.assign_add(1)

            print('Starting evaluation for epoch: {}'.format(epoch + 1))

            with eval_writer.as_default():
                for inputs in tqdm(test_dataset, desc='Evaluating'):
                    eval_step(inputs, eval_metric)

                accuracy = eval_metric.result()
                print('Accuracy: {}'.format(accuracy.numpy()))
                tf.summary.scalar('accuracy', accuracy, step=global_step)

                if accuracy >= best_metric:
                    checkpoint.save(file_prefix=checkpoint_prefix + '-best')
                    print('New best model saved: {} >= previous best {}'.format(
                        accuracy.numpy(), best_metric.numpy()))
                    best_metric.assign(accuracy)

            eval_metric.reset_states()
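
# `distributed_run` is repo-internal and not shown here. A minimal sketch of
# what such a decorator presumably does, assuming it simply forwards the call
# through strategy.run (an assumption, not the repo's actual implementation):
import functools

import tensorflow as tf


def distributed_run(strategy: tf.distribute.Strategy):
    def decorator(fn):
        step_fn = tf.function(fn)  # compile the per-replica step

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # run once per replica; the caller reduces the returned
            # per-replica values with strategy.reduce, as main() does above
            return strategy.run(step_fn, args=args, kwargs=kwargs)

        return wrapper

    return decorator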
Example #5
def _create_keras_model(mirrored_strategy: tf.distribute.MirroredStrategy,
                        params: BaseParams):
    with mirrored_strategy.scope():
        model = BertMultiTask(params)
    return model