コード例 #1
0
ファイル: test_comparison.py プロジェクト: pystiche/pystiche
    def test_call(self, encoder):
        """Check that GramLoss equals the MSE between input and target Gram matrices.

        Both Gram matrices are built from ``encoder`` outputs using the loss's
        own ``normalize`` setting, so the reference value mirrors its config.
        """
        # Seed first: the two random draws below must be reproducible, and the
        # target is drawn before the input (order matters for the RNG stream).
        torch.manual_seed(0)
        target = torch.rand(1, 3, 128, 128)
        image = torch.rand(1, 3, 128, 128)

        gram_loss = loss_.GramLoss(encoder)
        gram_loss.set_target_image(target)
        actual = gram_loss(image)

        # Independent reference computation of the same quantity.
        normalize = gram_loss.normalize
        input_gram = pystiche.gram_matrix(encoder(image), normalize=normalize)
        target_gram = pystiche.gram_matrix(encoder(target), normalize=normalize)
        desired = F.mse_loss(input_gram, target_gram)

        ptu.assert_allclose(actual, desired)
コード例 #2
0
def get_style_op(encoder, layer_weight):
    """Return a ``loss.GramLoss`` on ``encoder`` with ``score_weight=layer_weight``."""
    style_op = loss.GramLoss(encoder, score_weight=layer_weight)
    return style_op
コード例 #3
0
ファイル: _loss.py プロジェクト: jbueltemeier/pystiche_papers
 def encoding_loss_fn(encoder: enc.Encoder, layer_weight: float) -> loss.GramLoss:
     """Build a Gram loss for a single encoder, scaled by ``layer_weight``.

     Args:
         encoder: Encoder passed through to ``loss.GramLoss``.
         layer_weight: Value forwarded as the loss's ``score_weight``.

     Returns:
         The newly constructed ``loss.GramLoss`` instance.
     """
     gram_loss = loss.GramLoss(encoder, score_weight=layer_weight)
     return gram_loss
コード例 #4
0
            )) from error

    return mle_fn()


# Registry mapping a normalized loss key to a ``(display name, factory)`` pair.
# Each factory takes ``(encoder, score_weight)``. The table is maintained by
# hand because some losses (e.g. MRF) need extra parameters beyond those two,
# which are fixed here.
LOSSES = dict(
    featurereconstruction=(
        "FeatureReconstruction",
        lambda encoder, score_weight: loss.FeatureReconstructionLoss(
            encoder, score_weight=score_weight
        ),
    ),
    gram=(
        "Gram",
        lambda encoder, score_weight: loss.GramLoss(
            encoder, score_weight=score_weight
        ),
    ),
    mrf=(
        "MRF",
        # MRF-specific settings: 3x3 patches sampled with stride 2.
        lambda encoder, score_weight: loss.MRFLoss(
            encoder, patch_size=3, stride=2, score_weight=score_weight
        ),
    ),
)


def make_loss(
    loss_str: str, layers_str: str, score_weight: float,
    mle: enc.MultiLayerEncoder
) -> Union[loss.ComparisonLoss, loss.MultiLayerEncodingLoss]:
    loss_str_normalized = loss_str.lower().replace("_", "").replace("-", "")
    if loss_str_normalized not in LOSSES.keys():
コード例 #5
0
ファイル: test_comparison.py プロジェクト: pystiche/pystiche
 def test_repr_smoke(self, encoder):
     """Smoke test: ``repr`` of a GramLoss must produce a string."""
     representation = repr(loss_.GramLoss(encoder))
     assert isinstance(representation, str)