def test_history_per_batch_plot(outdir):
    """Check that HistoryPerBatch's plotting paths produce identical images.

    Feeds synthetic data into the callback: 100 batch ``loss`` values per
    epoch plus one ``val_loss`` per epoch. Both the mean and the spread of
    the data shrink as 1/epoch. The plot is then written three ways:

    1. the callback returned by ``plot_callback``;
    2. ``fig.savefig`` on the figure returned by ``plot()``;
    3. ``plot(save_as=...)``.

    The test asserts that the three files are byte-identical.
    """
    # Local seeded generator so the synthetic losses, and therefore the
    # rendered images, are reproducible across runs.
    rng = np.random.RandomState(0)

    hist = HistoryPerBatch()
    hist.params = {'metrics': ['loss', 'val_loss']}
    hist.on_train_begin(0)
    path_cb = str(outdir.join("callback_plot.png"))
    plot_cb = hist.plot_callback(fname=path_cb)

    n_epochs = 50
    n_batches = 100
    # Used as both the mean and the standard deviation of the fake losses.
    scale = 1 / np.arange(1, n_epochs + 1)
    for epoch in range(n_epochs):
        hist.on_epoch_begin(epoch)
        for batch in range(n_batches):
            hist.on_batch_begin(batch)
            # With no `size` argument, normal() returns a scalar. The original
            # drew a 1-element array and called float() on it, which NumPy
            # >= 1.25 deprecates.
            loss = float(rng.normal(scale[epoch], scale[epoch]))
            hist.on_batch_end(batch, logs={'loss': loss})
        val_loss = float(rng.normal(scale[epoch], scale[epoch]))
        hist.on_epoch_end(epoch, logs={'val_loss': val_loss})

    # Trigger the plot callback once, for the last epoch, instead of relying
    # on the loop variable leaking out of the for-loop.
    plot_cb.on_epoch_end(n_epochs - 1)

    fig, axes = hist.plot()
    path1 = str(outdir.join("callback_history.png"))
    fig.savefig(path1)
    path2 = str(outdir.join("callback_history2.png"))
    hist.plot(save_as=path2)

    # shallow=False forces a byte-by-byte comparison rather than an
    # os.stat() signature check. The original discarded these results, so
    # the test verified nothing about the images.
    assert filecmp.cmp(path1, path2, shallow=False)
    assert filecmp.cmp(path_cb, path2, shallow=False)