Ejemplo n.º 1
0
def test_invalid_save_freq(tmp_path):
    """Constructing the callback with an unrecognised ``save_freq`` raises ValueError."""
    checkpoint_file = str(tmp_path / "test_model.h5")
    bad_frequency = "invalid_save_freq"
    with pytest.raises(ValueError, match="Unrecognized save_freq"):
        AverageModelCheckpoint(
            filepath=checkpoint_file,
            save_freq=bad_frequency,
            update_weights=True,
        )
def test_model_file_creation(tmp_path):
    """Fitting with the callback attached leaves a checkpoint file on disk."""
    checkpoint_file = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    callback = AverageModelCheckpoint(filepath=checkpoint_file, update_weights=True)
    model.fit(
        x,
        y,
        callbacks=[callback],
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
    )
    assert os.path.exists(checkpoint_file)
def test_compatibility_with_some_opts_only(tmp_path):
    """fit() raises TypeError when the optimizer is not an averaging wrapper.

    The model here is built with a plain "rmsprop" optimizer, so attaching an
    AverageModelCheckpoint is expected to be rejected once training starts.
    """
    checkpoint_file = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model(optimizer="rmsprop")
    callback = AverageModelCheckpoint(filepath=checkpoint_file, update_weights=True)
    # NOTE: "trainingwith" (no space) is kept verbatim; presumably it mirrors
    # the exact text raised by the callback -- verify before "fixing" it.
    expected_message = (
        "AverageModelCheckpoint is only used when trainingwith"
        " MovingAverage or StochasticAverage"
    )
    with pytest.raises(TypeError, match=expected_message):
        model.fit(x, y, callbacks=[callback], batch_size=BATCH_SIZE, epochs=EPOCHS)
Ejemplo n.º 4
0
def test_save_best_only(tmp_path):
    """With ``save_best_only=True`` and validation data, a checkpoint is still written."""
    checkpoint_file = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    callback = AverageModelCheckpoint(
        save_best_only=True,
        filepath=checkpoint_file,
        update_weights=True,
    )
    model.fit(
        x, y,
        validation_data=(x, y),
        callbacks=[callback],
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
    )
    assert os.path.exists(checkpoint_file)
Ejemplo n.º 5
0
def test_loss_scale_optimizer(tmp_path):
    """Training works when a MovingAverage optimizer is wrapped in LossScaleOptimizer.

    The model is fit with an AverageModelCheckpoint configured to save every
    epoch; the test passes when fit() completes and the final assertion holds.
    """
    test_model_filepath = str(tmp_path / "test_model.{epoch:02d}.h5")
    # Use the documented `learning_rate` keyword. `lr` is a deprecated alias:
    # depending on the Keras version it is either rejected or only warned
    # about and replaced by the default learning rate, which would silently
    # drop the intended value of 2.0.
    moving_avg = MovingAverage(tf.keras.optimizers.SGD(learning_rate=2.0),
                               average_decay=0.5)
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(moving_avg)
    x, y, model = get_data_and_model(optimizer)
    save_freq = "epoch"
    avg_model_ckpt = AverageModelCheckpoint(update_weights=False,
                                            filepath=test_model_filepath,
                                            save_freq=save_freq)
    model.fit(
        x,
        y,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(x, y),
        callbacks=[avg_model_ckpt],
    )
    # NOTE(review): test_model_filepath still contains the unformatted
    # "{epoch:02d}" placeholder, so this only checks that nothing was written
    # under the literal template name; per-epoch files are not inspected.
    assert not os.path.exists(test_model_filepath)
Ejemplo n.º 6
0
def test_metric_unavailable(tmp_path):
    """Monitoring an unlogged metric name with save_best_only writes no checkpoint.

    The monitored name "unknown" is presumably never reported by training, so
    there is never a "best" value to trigger a save.
    """
    checkpoint_file = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    callback = AverageModelCheckpoint(
        filepath=checkpoint_file,
        monitor="unknown",
        save_best_only=True,
        update_weights=False,
    )
    model.fit(x, y, epochs=EPOCHS, batch_size=BATCH_SIZE,
              validation_data=(x, y), callbacks=[callback])
    assert not os.path.exists(checkpoint_file)
Ejemplo n.º 7
0
def test_save_freq(tmp_path):
    """``save_freq="epoch"`` writes one file per epoch, named from the filepath template."""
    filename_template = str(tmp_path / "test_model.{epoch:02d}.h5")
    x, y, model = get_data_and_model()
    callback = AverageModelCheckpoint(
        save_freq="epoch",
        filepath=filename_template,
        update_weights=False,
    )
    model.fit(
        x, y,
        validation_data=(x, y),
        callbacks=[callback],
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
    )
    # Checkpoints are expected for epochs 1 through 5 (presumably EPOCHS == 5).
    for epoch_number in range(1, 6):
        assert os.path.exists(filename_template.format(epoch=epoch_number))
Ejemplo n.º 8
0
def get_callbacks(params, val_dataset=None):
    """Build the training callbacks selected by ``params``.

    Args:
      params: config mapping. Always read: 'moving_average_decay',
        'model_dir', 'verbose', 'save_freq', 'model_optimizations',
        'steps_per_execution', 'profile'. Read only when relevant:
        'sample_image', 'img_summary_steps', 'map_freq', 'strategy'.
      val_dataset: optional validation data; COCO evaluation is scheduled
        only when this is truthy.

    Returns:
      A list ordered as: checkpoint callback; pruning callbacks or a
      TensorBoard callback; optional DisplayCallback; optional COCOCallback.
    """
    # Checkpointing: averaged weights ("emackpt") when a moving-average decay
    # is configured, otherwise a plain Keras ModelCheckpoint.
    if not params['moving_average_decay']:
        checkpoint = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(params['model_dir'], 'ckpt-{epoch:d}'),
            verbose=params['verbose'],
            save_freq=params['save_freq'],
            save_weights_only=True)
    else:
        checkpoint = AverageModelCheckpoint(
            filepath=os.path.join(params['model_dir'], 'emackpt-{epoch:d}'),
            verbose=params['verbose'],
            save_freq=params['save_freq'],
            save_weights_only=True,
            update_weights=False)
    callbacks = [checkpoint]

    # Summaries: pruning-aware summaries replace TensorBoard when pruning is on.
    optimizations = params['model_optimizations']
    pruning_enabled = bool(optimizations) and 'prune' in optimizations
    summary_args = dict(
        log_dir=params['model_dir'],
        update_freq=params['steps_per_execution'],
        profile_batch=2 if params['profile'] else 0)
    if pruning_enabled:
        callbacks.extend([UpdatePruningStep(), PruningSummaries(**summary_args)])
    else:
        callbacks.append(tf.keras.callbacks.TensorBoard(**summary_args))

    sample_image = params.get('sample_image', None)
    if sample_image:
        callbacks.append(
            DisplayCallback(sample_image, params['model_dir'],
                            params['img_summary_steps']))

    # COCO mAP evaluation requires a validation set and is skipped on TPU.
    if (params.get('map_freq', None) and val_dataset
            and params['strategy'] != 'tpu'):
        callbacks.append(COCOCallback(val_dataset, params['map_freq']))
    return callbacks
Ejemplo n.º 9
0
def test_mode_min(tmp_path):
    """``mode="min"`` on ``val_loss`` with save_best_only off still writes a checkpoint."""
    checkpoint_file = str(tmp_path / "test_model.h5")
    x, y, model = get_data_and_model()
    checkpoint_options = {
        "update_weights": True,
        "filepath": checkpoint_file,
        "monitor": "val_loss",
        "save_best_only": False,
        "mode": "min",
    }
    callback = AverageModelCheckpoint(**checkpoint_options)
    model.fit(
        x, y,
        validation_data=(x, y),
        callbacks=[callback],
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
    )
    assert os.path.exists(checkpoint_file)
Ejemplo n.º 10
0
def get_callbacks(params, val_dataset):
    """Build the training callbacks selected by ``params``.

    Args:
      params: config mapping. Always read: 'model_dir',
        'model_optimizations', 'iterations_per_loop', 'profile'. Read only
        when relevant: 'sample_image', 'img_summary_steps', 'map_freq'.
      val_dataset: validation data handed to COCOCallback when 'map_freq'
        is set.

    Returns:
      A list ordered as: ModelCheckpoint; pruning callbacks or a TensorBoard
      callback; optional DisplayCallback; optional COCOCallback.
    """
    # Only a plain (non-averaged) Keras weights checkpoint is written here.
    ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(params['model_dir'], 'ckpt'),
        verbose=1,
        save_weights_only=True)
    callbacks = [ckpt_callback]

    # Summaries: pruning-aware summaries replace TensorBoard when pruning is on.
    if params['model_optimizations'] and 'prune' in params[
            'model_optimizations']:
        prune_callback = UpdatePruningStep()
        prune_summaries = PruningSummaries(
            log_dir=params['model_dir'],
            update_freq=params['iterations_per_loop'],
            profile_batch=2 if params['profile'] else 0)
        callbacks += [prune_callback, prune_summaries]
    else:
        tb_callback = tf.keras.callbacks.TensorBoard(
            log_dir=params['model_dir'],
            update_freq=params['iterations_per_loop'],
            profile_batch=2 if params['profile'] else 0)
        callbacks.append(tb_callback)

    sample_image = params.get('sample_image', None)
    if sample_image:
        display_callback = DisplayCallback(sample_image,
                                           params['model_dir'],
                                           params['img_summary_steps'])
        callbacks.append(display_callback)

    # NOTE(review): unlike the variant that accepts val_dataset=None, this
    # does not check val_dataset before building COCOCallback -- confirm that
    # callers always pass a dataset whenever 'map_freq' is set.
    if params.get('map_freq', None):
        coco_callback = COCOCallback(val_dataset, params['map_freq'])
        callbacks.append(coco_callback)
    return callbacks