class WEIGHTED_SIMILARITY:
    """Weighted sum of a primary reconstruction loss and an SSIM-based term.

    Calling an instance returns::

        weights[0] * main_loss(original, reconstructions)
        + weights[1] * 100 * (1 - ssim(transform(original),
                                       transform(reconstructions)))

    NOTE(review): ``MSE_LOSS``, ``BCE_LOSS`` and ``SSIM`` are defined outside
    this block; presumably they alias ``torch.nn.MSELoss``,
    ``torch.nn.BCELoss`` and an SSIM module -- confirm against the imports.
    """

    def __init__(self, primary_loss="mse", weights=None, asImages=True):
        """Build the two loss components.

        Args:
            primary_loss: Name of the primary loss. Any value containing the
                substring ``"bce"`` selects ``BCE_LOSS``; everything else
                selects ``MSE_LOSS``.
            weights: Two-element sequence ``[primary_weight, ssim_weight]``.
                ``None`` (the default) means ``[1.0, 1.0]``.
            asImages: When true, inputs are treated as image batches;
                otherwise as video batches (see ``transform``).
        """
        if "bce" in primary_loss:
            self.main_loss = BCE_LOSS(reduction="mean")
        else:
            self.main_loss = MSE_LOSS(reduction="mean")
        self.ssim_loss = SSIM(data_range=1.0, nonnegative_ssim=True)
        # A fresh list per instance: a list literal used as the default
        # argument would be shared (and mutable) across every instance.
        self.weights = [1.0, 1.0] if weights is None else weights
        self.asImages = asImages

    def transform(self, x):
        """Reshape ``x`` into the layout handed to the SSIM module.

        Images: a single-channel input (size 1 at dim -3) is repeated to
        3 channels; anything else is returned untouched.

        Videos: a single-channel input (size 1 at dim -4) is repeated to
        3 channels, then dims 1 and 2 are swapped and dims 0 and 1 merged.
        Assumes the video layout is batch x channels x frames x h x w, so
        the result is (batch * frames) x channels x h x w.
        """
        if self.asImages:
            # Gray image -> 3 channels.
            if x.shape[-3] == 1:
                x = x.repeat(1, 3, 1, 1)
            return x

        # Gray frames -> 3 channels.
        if x.shape[-4] == 1:
            x = x.repeat(1, 3, 1, 1, 1)
        # Swap channels and frames, then fold frames into the batch dim.
        frames_first = x.transpose(1, 2)
        return frames_first.flatten(start_dim=0, end_dim=1)

    def __call__(self, original, reconstructions):
        """Return the weighted primary loss plus the weighted SSIM term.

        The primary loss sees the raw inputs; only the SSIM term sees the
        ``transform``-ed inputs. The SSIM term is ``1 - ssim`` scaled by a
        fixed factor of 100.
        """
        primary_term = self.weights[0] * self.main_loss(original, reconstructions)
        similarity = self.ssim_loss.forward(
            self.transform(original), self.transform(reconstructions)
        )
        ssim_term = self.weights[1] * 100 * (1 - similarity)
        return primary_term + ssim_term
class WEIGHTED_SIMILARITY:
    """Weighted sum of a primary reconstruction loss and an SSIM-based term.

    Calling an instance returns::

        weights[0] * main_loss(original, reconstructions)
        + weights[1] * 100 * (1 - ssim(original, reconstructions))

    NOTE(review): ``MSE_LOSS``, ``BCE_LOSS`` and ``SSIM`` are defined outside
    this block; presumably they alias ``torch.nn.MSELoss``,
    ``torch.nn.BCELoss`` and an SSIM module -- confirm against the imports.
    """

    def __init__(self, primary_loss="mse", weights=None):
        """Build the two loss components.

        Args:
            primary_loss: Name of the primary loss. Any value containing the
                substring ``"bce"`` selects ``BCE_LOSS``; everything else
                selects ``MSE_LOSS``.
            weights: Two-element sequence ``[primary_weight, ssim_weight]``.
                ``None`` (the default) means ``[1.0, 1.0]``.
        """
        if "bce" in primary_loss:
            self.main_loss = BCE_LOSS(reduction="mean")
        else:
            self.main_loss = MSE_LOSS(reduction="mean")
        # SSIM is built with its own defaults here.
        self.ssim_loss = SSIM()
        # A fresh list per instance: a list literal used as the default
        # argument would be shared (and mutable) across every instance.
        self.weights = [1.0, 1.0] if weights is None else weights

    def __call__(self, original, reconstructions):
        """Return the weighted primary loss plus the weighted SSIM term.

        The SSIM term is ``1 - ssim`` scaled by a fixed factor of 100.
        """
        primary_term = self.weights[0] * self.main_loss(original, reconstructions)
        similarity = self.ssim_loss.forward(original, reconstructions)
        ssim_term = self.weights[1] * 100 * (1 - similarity)
        return primary_term + ssim_term