예제 #1
0
    def test_call_guided(self, encoder):
        """Check that a guided MRFLoss only compares the guided regions.

        The target guide is 0 in the first 16 rows (dim 2) and 1 in the last
        16 rows; the input guide is its flip along dim 2, so it is 1 in the
        first 16 rows. The expected value is therefore the functional MRF loss
        between the first 16 rows of the input encoding and the last 16 rows
        of the target encoding.
        """
        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)

        guide_off = torch.zeros(1, 1, 16, 32)
        guide_on = torch.ones(1, 1, 16, 32)
        target_guide = torch.cat((guide_off, guide_on), dim=2)
        input_guide = target_guide.flip(2)

        patch_size = stride = 2
        criterion = loss_.MRFLoss(encoder, patch_size, stride=stride)
        criterion.set_target_image(target_image, guide=target_guide)
        criterion.set_input_guide(input_guide)
        actual = criterion(input_image)

        # NOTE(review): the fixed split at row 16 assumes the encoder keeps the
        # 32x32 spatial size of the images; confirm against the fixture.
        input_enc = encoder(input_image)[:, :, :16, :]
        target_enc = encoder(target_image)[:, :, 16:, :]

        def patches(enc):
            return pystiche.extract_patches2d(enc, patch_size, stride=stride)

        desired = F.mrf_loss(
            patches(input_enc), patches(target_enc), batched_input=True
        )
        ptu.assert_allclose(actual, desired)
예제 #2
0
    def test_call(self, encoder):
        """Check an unguided MRFLoss against the functional MRF loss.

        Without guides, the loss should equal ``F.mrf_loss`` applied to the
        patches of the full input and target encodings.
        """
        torch.manual_seed(0)
        target_image = torch.rand(1, 3, 32, 32)
        input_image = torch.rand(1, 3, 32, 32)

        patch_size, stride = 3, 2
        criterion = loss_.MRFLoss(encoder, patch_size, stride=stride)
        criterion.set_target_image(target_image)
        actual = criterion(input_image)

        def patches(image):
            # Encode the image first, then extract patches from the encoding.
            encoding = encoder(image)
            return pystiche.extract_patches2d(encoding, patch_size, stride=stride)

        desired = F.mrf_loss(
            patches(input_image), patches(target_image), batched_input=True
        )
        ptu.assert_allclose(actual, desired)
def get_style_op(encoder, layer_weight, *, patch_size=3, stride=2):
    """Build the MRF style loss for a single encoder.

    Args:
        encoder: Encoder passed straight through to ``loss.MRFLoss``.
        layer_weight: Forwarded as the ``score_weight`` of the loss.
        patch_size: Forwarded as ``patch_size`` to ``loss.MRFLoss``.
            Defaults to 3, the previously hard-coded value.
        stride: Forwarded as ``stride`` to ``loss.MRFLoss``.
            Defaults to 2, the previously hard-coded value.

    Returns:
        The constructed ``loss.MRFLoss`` instance.
    """
    # The extra parameters are keyword-only so existing two-argument callers,
    # e.g. code that invokes this as ``fn(encoder, layer_weight)``, are unaffected.
    return loss.MRFLoss(
        encoder,
        patch_size=patch_size,
        stride=stride,
        score_weight=layer_weight,
    )
예제 #4
0
# Registry of the supported losses, maintained by hand: some losses (e.g. MRF)
# need constructor arguments beyond ``encoder`` and ``score_weight``, so the
# entries cannot be generated uniformly.
#
# Keys are lowercase names without "_" or "-". Each value is a
# ``(display_name, factory)`` pair, where ``factory(encoder, score_weight)``
# builds the loss lazily (``loss`` is only looked up when the factory is called).
LOSSES = {
    "featurereconstruction": (
        "FeatureReconstruction",
        lambda encoder, score_weight: loss.FeatureReconstructionLoss(
            encoder, score_weight=score_weight
        ),
    ),
    "gram": (
        "Gram",
        lambda encoder, score_weight: loss.GramLoss(
            encoder, score_weight=score_weight
        ),
    ),
    # MRF additionally pins its patch geometry here.
    "mrf": (
        "MRF",
        lambda encoder, score_weight: loss.MRFLoss(
            encoder, patch_size=3, stride=2, score_weight=score_weight
        ),
    ),
}


def make_loss(
    loss_str: str, layers_str: str, score_weight: float,
    mle: enc.MultiLayerEncoder
) -> Union[loss.ComparisonLoss, loss.MultiLayerEncodingLoss]:
    loss_str_normalized = loss_str.lower().replace("_", "").replace("-", "")
    if loss_str_normalized not in LOSSES.keys():
        raise ValueError(
            add_suggestion(
                f"Unknown loss '{loss_str}'.",
                word=loss_str_normalized,
                possibilities=tuple(zip(*LOSSES.values()))[0],
예제 #5
0
 def test_repr_smoke(self, encoder):
     """Smoke test: ``repr`` of an MRFLoss evaluates to a string."""
     criterion = loss_.MRFLoss(encoder, 2, stride=1)
     text = repr(criterion)
     assert isinstance(text, str)