def test_invalid_save_freq(tmp_path):
    """An unrecognized ``save_freq`` string is rejected when the callback is built."""
    bad_freq = "invalid_save_freq"
    model_path = str(tmp_path / "test_model.h5")
    with pytest.raises(ValueError, match="Unrecognized save_freq"):
        AverageModelCheckpoint(
            update_weights=True,
            filepath=model_path,
            save_freq=bad_freq,
        )
def test_model_file_creation(tmp_path):
    """Fitting with the averaging checkpoint callback writes the model file."""
    model_path = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    checkpoint_cb = AverageModelCheckpoint(update_weights=True, filepath=model_path)
    model.fit(
        x,
        y,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[checkpoint_cb],
    )
    assert os.path.exists(model_path)
def test_compatibility_with_some_opts_only(tmp_path):
    """Fitting with a non-averaging optimizer ("rmsprop") raises ``TypeError``.

    NOTE(review): the expected message contains "trainingwith" with no space.
    Presumably this mirrors the exact text raised by AverageModelCheckpoint;
    confirm against the library before "fixing" the spacing here.
    """
    model_path = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model(optimizer="rmsprop")
    checkpoint_cb = AverageModelCheckpoint(update_weights=True, filepath=model_path)
    expected_msg = (
        "AverageModelCheckpoint is only used when trainingwith"
        " MovingAverage or StochasticAverage"
    )
    with pytest.raises(TypeError, match=expected_msg):
        model.fit(
            x,
            y,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            callbacks=[checkpoint_cb],
        )
def test_save_best_only(tmp_path):
    """With ``save_best_only=True`` and validation data, a checkpoint is written."""
    model_path = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    checkpoint_cb = AverageModelCheckpoint(
        update_weights=True,
        filepath=model_path,
        save_best_only=True,
    )
    model.fit(
        x,
        y,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(x, y),
        callbacks=[checkpoint_cb],
    )
    assert os.path.exists(model_path)
def test_loss_scale_optimizer(tmp_path):
    """Train with a ``MovingAverage`` optimizer wrapped in a ``LossScaleOptimizer``.

    The callback is configured with ``update_weights=False`` and
    ``save_freq="epoch"`` and a per-epoch filename template.
    """
    test_model_filepath = str(tmp_path / "test_model.{epoch:02d}.h5")
    # `lr` is a deprecated alias for `learning_rate` on Keras optimizers;
    # newer Keras releases warn about or reject it, so use the canonical name.
    moving_avg = MovingAverage(
        tf.keras.optimizers.SGD(learning_rate=2.0), average_decay=0.5
    )
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(moving_avg)
    x, y, model = get_data_and_model(optimizer)
    avg_model_ckpt = AverageModelCheckpoint(
        update_weights=False, filepath=test_model_filepath, save_freq="epoch"
    )
    model.fit(
        x,
        y,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(x, y),
        callbacks=[avg_model_ckpt],
    )
    # NOTE(review): test_model_filepath still contains the literal
    # "{epoch:02d}" placeholder, so this only checks that no file with that
    # unformatted name exists. It does not inspect per-epoch checkpoints;
    # consider asserting on test_model_filepath.format(epoch=...) once the
    # expected behavior is confirmed.
    assert not os.path.exists(test_model_filepath)
def test_metric_unavailable(tmp_path):
    """Monitoring a metric name training does not report leaves no checkpoint.

    With ``save_best_only=True`` and ``monitor="unknown"``, no file is
    expected at the checkpoint path after fitting.
    """
    model_path = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    checkpoint_cb = AverageModelCheckpoint(
        update_weights=False,
        filepath=model_path,
        monitor="unknown",
        save_best_only=True,
    )
    model.fit(
        x,
        y,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(x, y),
        callbacks=[checkpoint_cb],
    )
    assert not os.path.exists(model_path)
def test_save_freq(tmp_path):
    """With ``save_freq="epoch"``, one checkpoint file exists per trained epoch.

    Filenames use the ``{epoch:02d}`` template, and the epoch numbers checked
    start at 1.
    """
    test_filepath = str(tmp_path / "test_model.{epoch:02d}.h5")
    x, y, model = get_data_and_model()
    avg_model_ckpt = AverageModelCheckpoint(
        update_weights=False, filepath=test_filepath, save_freq="epoch"
    )
    model.fit(
        x,
        y,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(x, y),
        callbacks=[avg_model_ckpt],
    )
    # Derive the expected epochs from EPOCHS rather than hard-coding 1..5, so
    # the assertions track the number of epochs actually trained.
    for epoch in range(1, EPOCHS + 1):
        expected = test_filepath.format(epoch=epoch)
        assert os.path.exists(expected), f"missing checkpoint for epoch {epoch}"
def get_callbacks(params, val_dataset=None):
    """Assemble the Keras callbacks requested by ``params``.

    Args:
      params: configuration mapping. Keys read: 'moving_average_decay',
        'model_dir', 'verbose', 'save_freq', 'model_optimizations',
        'steps_per_execution', 'profile'; optionally 'sample_image' (with
        'img_summary_steps') and 'map_freq' (with 'strategy').
      val_dataset: passed to COCOCallback; that callback is only added when
        this is truthy.

    Returns:
      A list whose first element is a checkpoint callback, followed by either
      the pruning callbacks or a TensorBoard callback, then the optional
      image-display and COCO callbacks.
    """
    if params['moving_average_decay']:
        # Weights-only checkpoints via AverageModelCheckpoint, without
        # swapping the averaged weights into the model (update_weights=False).
        checkpoint = AverageModelCheckpoint(
            filepath=os.path.join(params['model_dir'], 'emackpt-{epoch:d}'),
            verbose=params['verbose'],
            save_freq=params['save_freq'],
            save_weights_only=True,
            update_weights=False)
    else:
        checkpoint = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(params['model_dir'], 'ckpt-{epoch:d}'),
            verbose=params['verbose'],
            save_freq=params['save_freq'],
            save_weights_only=True)
    callbacks = [checkpoint]

    pruning = (params['model_optimizations']
               and 'prune' in params['model_optimizations'])
    if pruning:
        callbacks.append(UpdatePruningStep())
        callbacks.append(PruningSummaries(
            log_dir=params['model_dir'],
            update_freq=params['steps_per_execution'],
            profile_batch=2 if params['profile'] else 0))
    else:
        callbacks.append(tf.keras.callbacks.TensorBoard(
            log_dir=params['model_dir'],
            update_freq=params['steps_per_execution'],
            profile_batch=2 if params['profile'] else 0))

    sample_image = params.get('sample_image', None)
    if sample_image:
        callbacks.append(DisplayCallback(sample_image, params['model_dir'],
                                         params['img_summary_steps']))

    # COCO evaluation is skipped on TPU and when no validation data is given.
    wants_coco = (params.get('map_freq', None) and val_dataset
                  and params['strategy'] != 'tpu')
    if wants_coco:
        callbacks.append(COCOCallback(val_dataset, params['map_freq']))
    return callbacks
def test_mode_min(tmp_path):
    """Monitoring ``val_loss`` in ``"min"`` mode (not best-only) writes a checkpoint."""
    model_path = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    checkpoint_cb = AverageModelCheckpoint(
        update_weights=True,
        filepath=model_path,
        monitor="val_loss",
        save_best_only=False,
        mode="min",
    )
    model.fit(
        x,
        y,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(x, y),
        callbacks=[checkpoint_cb],
    )
    assert os.path.exists(model_path)
def get_callbacks(params, val_dataset):
    """Assemble the Keras callbacks requested by ``params``.

    Args:
      params: configuration mapping. Keys read: 'model_dir',
        'model_optimizations', 'iterations_per_loop', 'profile'; optionally
        'sample_image' (with 'img_summary_steps') and 'map_freq'.
      val_dataset: passed to COCOCallback when 'map_freq' is set.

    Returns:
      A list whose first element is a weights-only ModelCheckpoint, followed
      by either the pruning callbacks or a TensorBoard callback, then the
      optional image-display and COCO callbacks.
    """
    # Always a plain weights-only checkpoint. The AverageModelCheckpoint
    # alternative was guarded by `if False:` and could never run, so the
    # unreachable branch has been dropped.
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(params['model_dir'], 'ckpt'),
            verbose=1,
            save_weights_only=True)
    ]

    if params['model_optimizations'] and 'prune' in params[
            'model_optimizations']:
        prune_callback = UpdatePruningStep()
        prune_summaries = PruningSummaries(
            log_dir=params['model_dir'],
            update_freq=params['iterations_per_loop'],
            profile_batch=2 if params['profile'] else 0)
        callbacks += [prune_callback, prune_summaries]
    else:
        tb_callback = tf.keras.callbacks.TensorBoard(
            log_dir=params['model_dir'],
            update_freq=params['iterations_per_loop'],
            profile_batch=2 if params['profile'] else 0)
        callbacks.append(tb_callback)

    sample_image = params.get('sample_image', None)
    if sample_image:
        display_callback = DisplayCallback(sample_image, params['model_dir'],
                                           params['img_summary_steps'])
        callbacks.append(display_callback)

    if params.get('map_freq', None):
        coco_callback = COCOCallback(val_dataset, params['map_freq'])
        callbacks.append(coco_callback)

    return callbacks