def test_history_per_batch_plot(outdir):
    """Render HistoryPerBatch plots through every entry point and check the files.

    Simulates ``n`` epochs of 100 batches whose loss is drawn around a
    decaying mean. It then produces three PNGs in ``outdir``:

    - one via the callback returned by ``plot_callback``;
    - one by saving the figure returned from ``plot()``;
    - one via ``plot(save_as=...)``.
    """
    hist = HistoryPerBatch()
    hist.params = {'metrics': ['loss', 'val_loss']}
    hist.on_train_begin(0)
    path_cb = str(outdir.join("callback_plot.png"))
    plot_cb = hist.plot_callback(fname=path_cb)

    # Seeded so the synthetic losses, and therefore the plots, are reproducible.
    rng = np.random.RandomState(0)
    n = 50
    # Mean and spread both shrink as 1/epoch, mimicking a converging run.
    scale = 1 / np.arange(1, n + 1)
    for e in range(n):
        hist.on_epoch_begin(e)
        for b in range(100):
            hist.on_batch_begin(b)
            # Draw a scalar (no size argument): float() of a size-1 array
            # is deprecated since NumPy 1.25.
            batch_loss = float(rng.normal(scale[e], scale[e]))
            hist.on_batch_end(b, logs={'loss': batch_loss})
        val_loss = float(rng.normal(scale[e], scale[e]))
        hist.on_epoch_end(e, logs={'val_loss': val_loss})

    # Fire the plotting callback once, as if at the end of the final epoch.
    plot_cb.on_epoch_end(n - 1)
    fig, axes = hist.plot()
    path1 = str(outdir.join("callback_history.png"))
    fig.savefig(path1)
    path2 = str(outdir.join("callback_history2.png"))
    hist.plot(save_as=path2)

    # Every rendering path must have produced its file. Byte-for-byte
    # equality of the PNGs (e.g. via filecmp.cmp) is deliberately not
    # asserted. TODO(review): confirm the three renderings are identical
    # before tightening this to an equality check.
    assert outdir.join("callback_plot.png").exists()
    assert outdir.join("callback_history.png").exists()
    assert outdir.join("callback_history2.png").exists()


def test_history_per_batch(tmpdir):
    """Check per-batch and per-epoch recording and the files written at train end.

    Runs three epochs of five batches each. Batch losses are random; epoch
    logs are constant. The test then verifies:

    - ``history['loss']`` holds the per-batch losses grouped by epoch, and
      reading it is expected to emit a DeprecationWarning;
    - ``epoch_history`` holds the epoch-level logs;
    - ``on_train_end`` leaves ``history.json`` and ``history.png`` in the
      directory given to the constructor.
    """
    hist = HistoryPerBatch(str(tmpdir))
    hist.params = {'metrics': ['loss', 'val_loss']}

    n_epochs = 3
    n_batches = 5
    # Expected per-batch losses: one inner list per epoch.
    losses = []
    for epoch in range(n_epochs):
        hist.on_epoch_begin(epoch)
        losses.append([])
        for batch in range(n_batches):
            # Scalar draw: float() of a size-1 array is deprecated since
            # NumPy 1.25.
            loss = float(np.random.sample())
            hist.on_batch_end(batch, logs={'loss': loss})
            losses[-1].append(loss)
        hist.on_epoch_end(epoch, logs={'loss': 1, 'val_loss': 2})

    with pytest.warns(DeprecationWarning):
        assert hist.history['loss'] == losses

    assert hist.epoch_history['loss'] == [1] * n_epochs
    assert hist.epoch_history['val_loss'] == [2] * n_epochs

    hist.on_train_end()
    assert tmpdir.join("history.json").exists()
    assert tmpdir.join("history.png").exists()