コード例 #1
0
def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Save a checkpoint of the model.

    Attempts to be pre-emption safe by writing to a temporary file before
    a final rename and cleanup of past files.

    Args:
      ckpt_dir: str: path to store checkpoint files in.
      target: serializable flax object, usually a flax optimizer.
      step: int or float: training step number or other metric number.
      prefix: str: checkpoint file name prefix.
      keep: number of past checkpoint files to keep. The checkpoint written
        by this call is never removed, even if it sorts before older files.
      overwrite: bool: allow overwriting when writing a checkpoint.

    Returns:
      Filename of saved checkpoint.
    """
    # Write temporary checkpoint file.
    logging.info('Saving checkpoint at step: %s', step)
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    ckpt_parent = os.path.dirname(ckpt_path)
    # An empty dirname means "current directory", which needs no creation.
    if ckpt_parent:
        gfile.makedirs(ckpt_parent)

    logging.info('Writing to temporary checkpoint location: %s', ckpt_tmp_path)
    with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
        fp.write(serialization.to_bytes(target))

    # Rename only once serialization and writing have finished, so an
    # interrupted write never leaves a truncated file under the final name.
    gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', ckpt_path)

    # Remove old checkpoint files: all files matching the prefix are ordered
    # with natural_sort, and everything before the last `keep` entries goes.
    base_path = os.path.join(ckpt_dir, prefix)
    checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
    if len(checkpoint_files) > keep:
        # `step` is not necessarily the largest value on disk (it may be a
        # metric), so the file just written can sort before others. Never
        # delete it, otherwise the returned path would not exist.
        old_ckpts = [p for p in checkpoint_files[:-keep] if p != ckpt_path]
        for path in old_ckpts:
            logging.info('Removing checkpoint at %s', path)
            gfile.remove(path)

    return ckpt_path
コード例 #2
0
def main(argv):
    """Train an SST-2 model and log its validation accuracy.

    Args:
      argv: positional command-line arguments remaining after flag parsing;
        only the program name itself is accepted.

    Raises:
      app.UsageError: if extra positional arguments are given or the
        model_dir flag is not set.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Validate with an explicit raise rather than `assert`, which is stripped
    # when Python runs with -O and would let a None model_dir through.
    if FLAGS.model_dir is None:
        raise app.UsageError('Please provide model_dir.')
    if not gfile.exists(FLAGS.model_dir):
        gfile.makedirs(FLAGS.model_dir)

    tf.enable_v2_behavior()

    # Prepare data.
    data_source = input_pipeline.SST2DataSource(min_freq=FLAGS.min_freq)

    # Create model; vocabulary size and UNK index come from the data source.
    model = sst2_model.create_model(
        FLAGS.seed, FLAGS.batch_size, FLAGS.max_seq_len,
        dict(vocab_size=data_source.vocab_size,
             embedding_size=FLAGS.embedding_size,
             hidden_size=FLAGS.hidden_size,
             output_size=1,
             unk_idx=data_source.unk_idx,
             dropout=FLAGS.dropout,
             emb_dropout=FLAGS.emb_dropout,
             word_dropout_rate=FLAGS.word_dropout_rate))

    # Train the model. `train` returns (stats, model); the stats are unused.
    _, model = train(model,
                     learning_rate=FLAGS.learning_rate,
                     num_epochs=FLAGS.num_epochs,
                     seed=FLAGS.seed,
                     model_dir=FLAGS.model_dir,
                     data_source=data_source,
                     batch_size=FLAGS.batch_size,
                     checkpoints_to_keep=FLAGS.checkpoints_to_keep,
                     l2_reg=FLAGS.l2_reg)

    # Evaluate the model returned by `train` on the validation set.
    valid_batches = input_pipeline.get_batches(data_source.valid_dataset,
                                               batch_size=FLAGS.batch_size)
    metrics = evaluate(model, valid_batches)
    logging.info('Best validation accuracy: %.2f', metrics['acc'])
コード例 #3
0
def save_model(filename: str, model: nn.Module) -> None:
    """Serialize `model` with `serialization.to_bytes` and write it to `filename`.

    Parent directories are created first when `filename` has a directory
    component. The bytes are written directly to `filename` (no temporary
    file and rename), so an interrupted save can leave a partial file.

    Args:
      filename: destination path for the serialized model.
      model: the object to serialize.
    """
    parent = os.path.dirname(filename)
    # A bare filename has an empty dirname (the current directory), which
    # needs no creation; `os.makedirs('')` raises, and gfile backends may too.
    if parent:
        gfile.makedirs(parent)
    with gfile.GFile(filename, "wb") as fp:
        fp.write(serialization.to_bytes(model))