def test_history_iter():
    '''
    Test History's __iter__ method.

    Builds a History from 26 single-letter keys and checks that iterating
    over it yields exactly those keys, in the same order as the dict that
    was used to build it.
    '''
    max_history = 3
    test = {chr(97 + i): (i + 1, ) for i in range(26)}
    history = utils.History(max_history, **test)
    # Materialise the iterator first: zip() stops at the shorter input, so
    # zipping the History directly would pass vacuously if __iter__ yielded
    # too few (or no) keys.
    history_keys = list(history)
    assert len(history_keys) == len(test)
    for key, history_key in zip(test, history_keys):
        assert key == history_key
def test_history_build_multistate_version():
    '''
    Test History's build_multistate method.

    Appends one random entry per key to a History of length 1 and compares
    the transposed build_multistate() output with the flattened inputs.
    '''
    expected_a = npr.rand(5, 5)
    expected_b = npr.rand(1)

    history = utils.History(1, test_a=(5, 5), test_b=(1, ))
    history.append(test_a=expected_a, test_b=expected_b)

    # zip(*...) transposes the result; it must unpack into exactly two
    # sequences, otherwise the assignment below raises.
    transposed = zip(*history.build_multistate())
    hist_a, hist_b = transposed

    flat_a = expected_a.ravel()
    flat_b = expected_b.ravel()
    assert np.all(hist_a == flat_a)
    assert np.all(hist_b == flat_b)
def test_history_append():
    '''
    Test History's append method.

    Rows of the reference arrays are appended from last to first; the
    stored arrays are then expected to equal the reference arrays in their
    original (un-reversed) order.
    '''
    depth = 3
    history = utils.History(depth, test1d=(5, ), test2d=(5, 5))

    test1d = np.arange(depth * 5).reshape((depth, 5))
    test2d = np.arange(depth * 25).reshape((depth, 5, 5))

    # Walk both reference arrays backwards along their first axis.
    for row1d, row2d in zip(test1d[::-1], test2d[::-1]):
        history.append(test1d=row1d, test2d=row2d)

    for key, expected in (('test1d', test1d), ('test2d', test2d)):
        assert np.all(history[key] == expected)
def test_history_reset():
    '''
    Test History's reset method.

    Fills the history, then checks that reset() without arguments leaves
    every stored value equal to 0, and that reset() with explicit arrays of
    ones leaves every stored value equal to 1.
    '''
    depth = 3
    keys = ('test1d', 'test2d')
    history = utils.History(depth, test1d=(5, ), test2d=(5, 5))

    test1d = np.arange(depth * 5).reshape((depth, 5))
    test2d = np.arange(depth * 25).reshape((depth, 5, 5))

    # Append the rows from last to first (indices -1, -2, ..., -depth).
    for step in range(1, depth + 1):
        history.append(test1d=test1d[-step], test2d=test2d[-step])

    history.reset()
    for key in keys:
        assert np.all(history[key] == 0)

    history.reset(test1d=np.ones((5, )), test2d=np.ones((5, 5)))
    for key in keys:
        assert np.all(history[key] == 1)
def create_random_history(shape=(5, 5), max_history=10):
    '''
    Create a history object filled with random elements.

    The history holds three entries: ``weights`` and ``gradients`` (both of
    the given shape) and ``losses`` (shape ``()``). It is appended to
    ``max_history`` times with values drawn from ``npr.rand``.

    :param shape: (tuple) Shape of the ``weights`` and ``gradients`` entries.
    :param max_history: (int) The max history to store.
    :return: (History) A History object.
    '''
    entry_shapes = dict(weights=shape, gradients=shape, losses=())
    history = utils_common.History(max_history, **entry_shapes)

    for _ in range(max_history):
        # Draw order (weights, losses, gradients) is kept deliberately so a
        # seeded generator produces the same values as before.
        weights = npr.rand(*shape)
        losses = npr.rand()
        gradients = npr.rand(*shape)
        history.append(weights=weights, losses=losses, gradients=gradients)

    return history
def test_history_len():
    '''
    Test History's __len__ method.

    The length of the History is expected to equal the number of keys it
    was constructed with (26 single-letter keys here).
    '''
    depth = 3
    # 'a' -> (1,), 'b' -> (2,), ..., 'z' -> (26,)
    letter_codes = range(ord('a'), ord('z') + 1)
    shapes = {
        chr(code): (size, )
        for size, code in enumerate(letter_codes, start=1)
    }
    history = utils.History(depth, **shapes)
    assert len(history) == len(shapes)
def test_history_getitem():
    '''
    Test History's __getitem__ method.

    Each entry looked up by key is expected to have shape
    ``(max_history,) + entry_shape``.
    '''
    depth = 3
    entry_shapes = {'test1d': (5, ), 'test2d': (5, 5)}
    history = utils.History(depth, **entry_shapes)

    for key, entry_shape in entry_shapes.items():
        expected_shape = (depth, ) + entry_shape
        assert history[key].shape == expected_shape