def test_call_guided(self, encoder):
    """Guided MRF loss should compare only the guided regions of input and target.

    The target guide is 1 on the lower half of the image (rows 16:) and 0
    elsewhere. The input guide is its vertical mirror, so it is 1 on the upper
    half (rows :16). The expected value is the plain functional MRF loss between
    the upper half of the encoded input and the lower half of the encoded target.
    """
    patch_size, stride = 2, 2
    half = 16

    torch.manual_seed(0)
    # Draw order matters for reproducibility: target first, then input.
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    top_rows = torch.zeros(1, 1, half, 32)
    bottom_rows = torch.ones(1, 1, half, 32)
    target_guide = torch.cat((top_rows, bottom_rows), dim=2)
    input_guide = torch.flip(target_guide, dims=(2,))

    mrf = loss_.MRFLoss(encoder, patch_size, stride=stride)
    mrf.set_target_image(target_image, guide=target_guide)
    mrf.set_input_guide(input_guide)
    actual = mrf(input_image)

    # NOTE(review): slicing encodings by rows assumes the `encoder` fixture
    # preserves the spatial size of its input — confirm against the fixture.
    input_enc = encoder(input_image)[:, :, :half, :]
    target_enc = encoder(target_image)[:, :, half:, :]

    def to_patches(enc):
        # Local shorthand so both sides are patched with identical settings.
        return pystiche.extract_patches2d(enc, patch_size, stride=stride)

    desired = F.mrf_loss(
        to_patches(input_enc), to_patches(target_enc), batched_input=True
    )

    ptu.assert_allclose(actual, desired)
def test_call(self, encoder):
    """Unguided MRF loss should equal the functional MRF loss on encoded patches."""
    patch_size = 3
    stride = 2

    torch.manual_seed(0)
    # Keep the draw order (target, then input) so the seeded values are unchanged.
    target_image = torch.rand(1, 3, 32, 32)
    input_image = torch.rand(1, 3, 32, 32)

    criterion = loss_.MRFLoss(encoder, patch_size, stride=stride)
    criterion.set_target_image(target_image)
    actual = criterion(input_image)

    input_patches = pystiche.extract_patches2d(
        encoder(input_image), patch_size, stride=stride
    )
    target_patches = pystiche.extract_patches2d(
        encoder(target_image), patch_size, stride=stride
    )
    desired = F.mrf_loss(input_patches, target_patches, batched_input=True)

    ptu.assert_allclose(actual, desired)
def get_style_op(encoder, layer_weight):
    """Build an MRF style loss for a single encoder.

    The loss uses a patch size of 3 and a stride of 2. ``layer_weight`` is
    forwarded as the loss's ``score_weight``.
    """
    mrf_options = dict(patch_size=3, stride=2, score_weight=layer_weight)
    return loss.MRFLoss(encoder, **mrf_options)
# This needs to be filled manually, since some losses such as MRF need more parameters # than just encoder and score_weight LOSSES = { "featurereconstruction": ( "FeatureReconstruction", lambda encoder, score_weight: loss.FeatureReconstructionLoss( encoder, score_weight=score_weight), ), "gram": ( "Gram", lambda encoder, score_weight: loss.GramLoss(encoder, score_weight=score_weight), ), "mrf": ( "MRF", lambda encoder, score_weight: loss.MRFLoss( encoder, patch_size=3, stride=2, score_weight=score_weight), ), } def make_loss( loss_str: str, layers_str: str, score_weight: float, mle: enc.MultiLayerEncoder ) -> Union[loss.ComparisonLoss, loss.MultiLayerEncodingLoss]: loss_str_normalized = loss_str.lower().replace("_", "").replace("-", "") if loss_str_normalized not in LOSSES.keys(): raise ValueError( add_suggestion( f"Unknown loss '{loss_str}'.", word=loss_str_normalized, possibilities=tuple(zip(*LOSSES.values()))[0],
def test_repr_smoke(self, encoder):
    """Calling repr() on an MRFLoss should return a str without raising."""
    mrf = loss_.MRFLoss(encoder, 2, stride=1)
    text = repr(mrf)
    assert isinstance(text, str)