import os

from absl import logging
from flax import serialization
from tensorflow.io import gfile


def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to a temporary file before a
  final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.
    overwrite: bool: allow overwriting when writing a checkpoint.

  Returns:
    Filename of saved checkpoint.
  """
  # Write the checkpoint to a temporary location first, so a pre-emption
  # mid-write can never corrupt an existing checkpoint file.
  logging.info('Saving checkpoint at step: %s', step)
  ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
  gfile.makedirs(os.path.dirname(ckpt_path))
  logging.info('Writing to temporary checkpoint location: %s', ckpt_tmp_path)
  with gfile.GFile(ckpt_tmp_path, 'wb') as fp:
    fp.write(serialization.to_bytes(target))
  # Rename only once serialization and writing have finished.
  gfile.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite)
  logging.info('Saved checkpoint at %s', ckpt_path)

  # Remove old checkpoint files, keeping only the `keep` most recent ones.
  base_path = os.path.join(ckpt_dir, prefix)
  checkpoint_files = natural_sort(gfile.glob(base_path + '*'))
  if len(checkpoint_files) > keep:
    old_ckpts = checkpoint_files[:-keep]
    for path in old_ckpts:
      logging.info('Removing checkpoint at %s', path)
      gfile.remove(path)

  return ckpt_path
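# `save_checkpoint` relies on two module-level helpers that are not shown
# here. Below is a minimal sketch of plausible implementations; these are
# assumptions for illustration, not the original source.
import re


def _checkpoint_path(ckpt_dir, step, prefix='checkpoint_'):
  # Builds e.g. '<ckpt_dir>/checkpoint_1000' (or 'checkpoint_tmp' for the
  # temporary file above).
  return os.path.join(ckpt_dir, f'{prefix}{step}')


def natural_sort(file_list):
  # Orders paths so that 'checkpoint_10' sorts after 'checkpoint_9', by
  # comparing embedded digit runs numerically instead of lexicographically.
  def key(path):
    return [int(t) if t.isdigit() else t.lower()
            for t in re.split(r'(\d+)', path)]
  return sorted(file_list, key=key)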
def main(argv):
  """Main function."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  assert FLAGS.model_dir is not None, 'Please provide model_dir.'
  if not gfile.exists(FLAGS.model_dir):
    gfile.makedirs(FLAGS.model_dir)

  tf.enable_v2_behavior()

  # Prepare data.
  data_source = input_pipeline.SST2DataSource(min_freq=FLAGS.min_freq)

  # Create model.
  model = sst2_model.create_model(
      FLAGS.seed,
      FLAGS.batch_size,
      FLAGS.max_seq_len,
      dict(vocab_size=data_source.vocab_size,
           embedding_size=FLAGS.embedding_size,
           hidden_size=FLAGS.hidden_size,
           output_size=1,
           unk_idx=data_source.unk_idx,
           dropout=FLAGS.dropout,
           emb_dropout=FLAGS.emb_dropout,
           word_dropout_rate=FLAGS.word_dropout_rate))

  # Train the model.
  train_stats, model = train(
      model,
      learning_rate=FLAGS.learning_rate,
      num_epochs=FLAGS.num_epochs,
      seed=FLAGS.seed,
      model_dir=FLAGS.model_dir,
      data_source=data_source,
      batch_size=FLAGS.batch_size,
      checkpoints_to_keep=FLAGS.checkpoints_to_keep,
      l2_reg=FLAGS.l2_reg)

  # Evaluate the best model.
  valid_batches = input_pipeline.get_batches(
      data_source.valid_dataset, batch_size=FLAGS.batch_size)
  metrics = evaluate(model, valid_batches)
  logging.info('Best validation accuracy: %.2f', metrics['acc'])
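# `main` assumes that `input_pipeline`, `sst2_model`, `train`, and `evaluate`
# are defined elsewhere in the example, and reads its configuration from absl
# flags. A minimal sketch of the flag definitions and entry point it relies
# on follows; the flag names are taken from the FLAGS attributes used above,
# but the default values here are illustrative assumptions only.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('model_dir', None, 'Directory for checkpoints and logs.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('batch_size', 64, 'Batch size.')
flags.DEFINE_integer('max_seq_len', 60, 'Maximum input sequence length.')
flags.DEFINE_integer('min_freq', 5, 'Minimum token frequency for the vocab.')
flags.DEFINE_integer('embedding_size', 256, 'Word embedding size.')
flags.DEFINE_integer('hidden_size', 256, 'Hidden layer size.')
flags.DEFINE_float('dropout', 0.5, 'Dropout rate.')
flags.DEFINE_float('emb_dropout', 0.5, 'Embedding dropout rate.')
flags.DEFINE_float('word_dropout_rate', 0.1, 'Word dropout rate.')
flags.DEFINE_float('learning_rate', 0.0005, 'Learning rate.')
flags.DEFINE_integer('num_epochs', 20, 'Number of training epochs.')
flags.DEFINE_integer('checkpoints_to_keep', 1, 'Checkpoints to keep.')
flags.DEFINE_float('l2_reg', 1e-6, 'L2 regularization weight.')


if __name__ == '__main__':
  app.run(main)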
def save_model(filename: str, model: nn.Module) -> None:
  """Serializes `model` to `filename` as Flax msgpack bytes."""
  gfile.makedirs(os.path.dirname(filename))
  with gfile.GFile(filename, "wb") as fp:
    fp.write(serialization.to_bytes(model))
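# A matching loader is the natural counterpart to `save_model`. This sketch
# (an addition, not from the original source) uses
# flax.serialization.from_bytes, which deserializes into a template object
# with the same structure as the one that was saved:
def load_model(filename: str, template: nn.Module) -> nn.Module:
  """Restores a model saved with `save_model` into `template`'s structure."""
  with gfile.GFile(filename, "rb") as fp:
    return serialization.from_bytes(template, fp.read())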