def test_generate_pastiche():
    """Verify the pastiche equals the content image but is a separate object.

    The value returned by ``st.generate_pastiche`` must be a ``Variable`` with
    the same size and the same values as the content image it was built from.
    """
    _, content = st.load_images()
    pastiche = st.generate_pastiche(content)

    assert isinstance(pastiche, Variable), 'pastiche must be of type Variable'
    assert pastiche.size() == content.size(), 'Expected: %s, Actual: %s' % (
        content.size(), pastiche.size())
    assert np.allclose(content.data.numpy(), pastiche.data.numpy())
    # Identity checks: these compare Python objects, not element values.
    assert content is not pastiche
    assert content.data is not pastiche.data
def test_construct_style_loss_fns():
    """Compare per-layer style losses with recorded reference values.

    One ``st.StyleLoss`` is expected per requested layer name. Each loss is
    applied to the corresponding output the VGG model returns for the
    pastiche, and the resulting scalars are checked against hard-coded
    numbers.
    """
    layer_names = ['r11', 'r21', 'r31', 'r41', 'r51']
    reference = [
        95157.6953125, 8318182.5, 4280054.5, 213536288.0, 26124.064453125
    ]

    model = utils.load_vgg()
    style_img, content_img = st.load_images()
    initial_pastiche = st.generate_pastiche(content_img)

    model_outputs = model(initial_pastiche, layer_names)
    criteria = st.construct_style_loss_fns(model, style_img, layer_names)

    assert len(criteria) == len(layer_names)
    assert all(isinstance(criterion, st.StyleLoss) for criterion in criteria)

    # Pair every loss function with the model output at the same position.
    measured = []
    for criterion, layer_output in zip(criteria, model_outputs):
        measured.append(criterion(layer_output).data[0])

    assert np.allclose(measured, reference), (
        'Expected: %s, Actual: %s' % (reference, measured))
def test_construct_content_loss_fns():
    """Compare the content loss of a noisy pastiche with a recorded value.

    The RNG is seeded so the Gaussian noise added to the pastiche is
    reproducible. One ``st.ContentLoss`` is expected per requested layer name.
    """
    torch.manual_seed(0)
    layer_names = ['r42']
    reference = [324.38946533203125]

    model = utils.load_vgg()
    style_img, content_img = st.load_images()

    # Build the pastiche first, then draw the noise, keeping the RNG order.
    clean_pastiche = st.generate_pastiche(content_img)
    noise = Variable(torch.randn(*content_img.size()))
    noisy_pastiche = clean_pastiche + noise

    model_outputs = model(noisy_pastiche, layer_names)
    criteria = st.construct_content_loss_fns(model, content_img, layer_names)

    assert len(criteria) == len(layer_names)
    assert all(isinstance(criterion, st.ContentLoss) for criterion in criteria)

    measured = []
    for criterion, layer_output in zip(criteria, model_outputs):
        measured.append(criterion(layer_output).data[0])

    assert np.allclose(measured, reference), (
        'Expected: %s, Actual: %s' % (reference, measured))
def test_load_images():
    """Compare ``st.load_images`` output with the stored ``.npy`` fixtures.

    Both returned images must be ``torch.FloatTensor`` instances whose size
    and values agree with the reference arrays under ``test_data/``.
    """
    style, content = st.load_images()

    assert isinstance(style, torch.FloatTensor), \
        'style_image must be FloatTensor'
    assert isinstance(content, torch.FloatTensor), \
        'content_image must be FloatTensor'

    # Load both reference arrays before converting them to tensors.
    style_ref = np.load('test_data/style_image_test.npy')
    content_ref = np.load('test_data/content_image_test.npy')
    style_ref = torch.FloatTensor(style_ref)
    content_ref = torch.FloatTensor(content_ref)

    size_msg = 'Expected: %s, Actual: %s'
    assert style_ref.size() == style.size(), size_msg % (
        style_ref.size(), style.size())
    assert content_ref.size() == content.size(), size_msg % (
        content_ref.size(), content.size())

    assert np.allclose(style_ref.numpy(), style.numpy())
    assert np.allclose(content_ref.numpy(), content.numpy())
def test_extract_vgg_features():
    """Smoke-test that ``st.extract_vgg_features`` can be called.

    The test passes as long as neither call raises; nothing is asserted about
    the result.
    """
    # The loaded images are not used further in this test.
    _style, _content = st.load_images()
    # NOTE(review): called without arguments and without checking the return
    # value -- presumably a placeholder; confirm the intended inputs.
    st.extract_vgg_features()
def test_generate_pastiche_size():
    """Check that the generated pastiche has the size of the content image.

    This test has its own name instead of reusing ``test_generate_pastiche``:
    a second ``def`` with that name rebinds the module attribute, so only the
    last definition would be collected and run.
    """
    _, content = st.load_images()
    pastiche = st.generate_pastiche(content)
    assert pastiche.size() == content.size(), 'Expected: %s, Actual: %s' % (
        content.size(), pastiche.size())