def test_generate_pastiche():
    """generate_pastiche must return a fresh Variable equal in value to the content image.

    Checks the return type, that size and values match ``content``, and that
    neither the wrapper object nor its underlying ``.data`` is shared with
    the input (the pastiche must be an independent copy).
    """
    _, content = st.load_images()
    pastiche = st.generate_pastiche(content)
    assert isinstance(pastiche, Variable), 'pastiche must be of type Variable'
    assert pastiche.size() == content.size()
    assert np.allclose(content.data.numpy(), pastiche.data.numpy())
    # Identity checks: a copy, not an alias of the input or of its storage wrapper.
    assert content is not pastiche
    assert content.data is not pastiche.data
def test_construct_style_loss_fns():
    """Style losses of an unmodified pastiche must match recorded reference values.

    Builds one StyleLoss per style layer, applies each to the VGG feature map
    of the pastiche at the same layer, and compares against fixed numbers.
    """
    vgg_model = utils.load_vgg()
    style_image, content_image = st.load_images()
    pastiche = st.generate_pastiche(content_image)

    style_layers = ['r11', 'r21', 'r31', 'r41', 'r51']
    features = vgg_model(pastiche, style_layers)

    loss_fns = st.construct_style_loss_fns(vgg_model, style_image, style_layers)
    assert len(loss_fns) == len(style_layers)
    assert all(isinstance(fn, st.StyleLoss) for fn in loss_fns)

    # Pair each loss with the feature map of its layer (same order as style_layers).
    losses = [fn(feature_map).data[0] for fn, feature_map in zip(loss_fns, features)]
    expected = [
        95157.6953125,
        8318182.5,
        4280054.5,
        213536288.0,
        26124.064453125,
    ]
    failure_msg = 'Expected: %s, Actual: %s' % (expected, losses)
    assert np.allclose(losses, expected), failure_msg
def test_construct_content_loss_fns():
    """Content loss of a noise-perturbed pastiche must match a recorded reference value.

    The expected value depends on the exact random noise drawn, so every call
    after ``torch.manual_seed(0)`` must keep its original order.
    """
    torch.manual_seed(0)
    # Keep this call order: anything below may consume the seeded RNG stream.
    vgg_model = utils.load_vgg()
    _style_image, content_image = st.load_images()
    pastiche = st.generate_pastiche(content_image)
    noise = Variable(torch.randn(*content_image.size()))
    pastiche = pastiche + noise

    content_layers = ['r42']
    features = vgg_model(pastiche, content_layers)

    loss_fns = st.construct_content_loss_fns(vgg_model, content_image, content_layers)
    assert len(loss_fns) == len(content_layers)
    assert all(isinstance(fn, st.ContentLoss) for fn in loss_fns)

    losses = [fn(feature_map).data[0] for fn, feature_map in zip(loss_fns, features)]
    expected = [324.38946533203125]
    failure_msg = 'Expected: %s, Actual: %s' % (expected, losses)
    assert np.allclose(losses, expected), failure_msg
def test_load_images():
    """load_images must return FloatTensors matching the stored .npy fixtures.

    Checks type, size and values of both the style and the content image
    against ``test_data/style_image_test.npy`` and
    ``test_data/content_image_test.npy``.
    """
    style, content = st.load_images()
    assert isinstance(style, torch.FloatTensor), 'style_image must be FloatTensor'
    assert isinstance(content, torch.FloatTensor), 'content_image must be FloatTensor'

    # Wrap fixtures as FloatTensors so their torch.Size compares directly.
    expected_style = torch.FloatTensor(np.load('test_data/style_image_test.npy'))
    expected_content = torch.FloatTensor(np.load('test_data/content_image_test.npy'))

    size_msg = 'Expected: %s, Actual: %s'
    assert expected_style.size() == style.size(), size_msg % (
        expected_style.size(), style.size())
    assert expected_content.size() == content.size(), size_msg % (
        expected_content.size(), content.size())

    assert np.allclose(expected_style.numpy(), style.numpy())
    assert np.allclose(expected_content.numpy(), content.numpy())
def test_extract_vgg_features():
    """Smoke test: load_images and extract_vgg_features run without raising."""
    # NOTE(review): the loaded images are never used and nothing is asserted;
    # presumably they were meant to be passed to extract_vgg_features and the
    # result checked — confirm against that function's signature.
    _style, _content = st.load_images()
    st.extract_vgg_features()
def test_generate_pastiche_preserves_size():
    """generate_pastiche must return a result with the same size as its input.

    Named distinctly because this module already defines a
    ``test_generate_pastiche``; reusing that name rebinds it at import time,
    so pytest would collect only one of the two tests.
    """
    _, content = st.load_images()
    pastiche = st.generate_pastiche(content)
    assert pastiche.size() == content.size()