示例#1
0
    def get_loss_functions(self, mask):
        """ Build the list of loss functions, one per entry in ``self.names``.

        Outputs whose name starts with "mask" use mean squared error. All
        other outputs use the loss named by the "loss_function" config option
        ("mae" when the option is absent). When ``mask`` is truthy and the
        "penalized_mask_loss" option is enabled, the output whose index equals
        ``self.largest_output`` has its loss wrapped in ``PenalizedLoss``
        together with ``mask[0]``.

        Returns the selected loss functions in the same order as
        ``self.names``.
        """
        selected = []
        final_idx = self.largest_output
        # Lookup table: configuration name -> loss callable
        available = {
            "mae": losses.mean_absolute_error,
            "mse": losses.mean_squared_error,
            "logcosh": losses.logcosh,
            "smooth_l1": generalized_loss,
            "l_inf_norm": l_inf_norm,
            "ssim": DSSIMObjective(),
            "gmsd": gmsd_loss,
            "pixel_gradient_diff": gradient_loss,
        }
        face_loss_name = self.config.get("loss_function", "mae")
        mask_loss_name = "mse"

        for position, output_name in enumerate(self.names):
            # Mask outputs never use the configurable face loss
            if output_name.startswith("mask"):
                selected.append(available[mask_loss_name])
                logger.debug("mask loss: %s", mask_loss_name)
                continue

            # Only the largest face output is eligible for the penalized loss
            if (mask and position == final_idx
                    and self.config.get("penalized_mask_loss", False)):
                wrapped = PenalizedLoss(mask[0], available[face_loss_name])
                selected.append(wrapped)
                logger.debug("final face loss: %s", face_loss_name)
                continue

            selected.append(available[face_loss_name])
            logger.debug("face loss func: %s", face_loss_name)

        logger.debug(selected)
        return selected
示例#2
0
文件: _base.py 项目: w7yuu/faceswap
    def get_loss_functions(self, side, predict, mask):
        """ Build the list of loss functions, one per entry in ``self.names``.

        The face loss is ``DSSIMObjective`` when the "dssim_loss" config
        option is enabled, otherwise mean absolute error. Outputs whose name
        starts with "mask" use mean squared error. When ``mask`` is truthy and
        "penalized_mask_loss" is enabled, the output whose index equals
        ``self.largest_output`` gets the face loss wrapped in
        ``PenalizedLoss`` together with ``mask[0]``.

        Verbose messages are emitted only when ``predict`` is falsy and
        ``side`` (case-insensitively) is "a".

        Returns the selected loss functions in the same order as
        ``self.names``.
        """
        selected = []
        final_idx = self.largest_output

        if self.config.get("dssim_loss", False):
            announce = not predict and side.lower() == "a"
            if announce:
                logger.verbose("Using DSSIM Loss")
            base_loss = DSSIMObjective()
        else:
            base_loss = losses.mean_absolute_error
            announce = not predict and side.lower() == "a"
            if announce:
                logger.verbose("Using Mean Absolute Error Loss")

        for position, output_name in enumerate(self.names):
            # Mask outputs always use mean squared error
            if output_name.startswith("mask"):
                mask_loss = losses.mean_squared_error
                selected.append(mask_loss)
                logger.debug("mask loss: %s", mask_loss)
                continue

            # Plain face loss unless this is the largest output and the
            # penalized mask loss has been requested
            if not (mask and position == final_idx
                    and self.config.get("penalized_mask_loss", False)):
                logger.debug("face loss func: %s", base_loss)
                selected.append(base_loss)
                continue

            penalized = PenalizedLoss(mask[0], base_loss)
            logger.debug("final face loss: %s", penalized)
            selected.append(penalized)
            if announce:
                logger.verbose("Penalizing mask for Loss")

        logger.debug(selected)
        return selected
示例#3
0
    def mask_loss_function(self, mask, side, initialize):
        """ Select the loss function for the mask output.

        The base loss is mean absolute error. When the "penalized_mask_loss"
        config option is enabled, it is wrapped in ``PenalizedLoss`` together
        with ``mask``.

        ``side`` and ``initialize`` only control logging: the verbose messages
        are emitted solely for side "a", outside of predict mode, and when
        ``initialize`` is truthy, so that they appear once rather than on
        every call.

        Returns the selected mask loss function.
        """
        # One shared gate for both verbose messages. The first message used
        # to ignore ``initialize`` and so was logged on every call for side
        # "a", contrary to the stated "log once" intent.
        announce = side == "a" and not self.predict and initialize

        if announce:
            logger.verbose("Using Mean Absolute Error Loss for mask")
        mask_loss_func = losses.mean_absolute_error

        if self.config.get("penalized_mask_loss", False):
            if announce:
                logger.verbose("Using Penalized Loss for mask")
            mask_loss_func = PenalizedLoss(mask, mask_loss_func)
        logger.debug(mask_loss_func)
        return mask_loss_func
示例#4
0
 def loss_function(self, mask, side, initialize):
     """ Select the loss function for the model.

     The base loss is mean absolute error. When ``mask`` is truthy and the
     "penalized_mask_loss" config option is enabled, it is wrapped in
     ``PenalizedLoss`` together with ``mask[0]``.

     ``side`` and ``initialize`` only control logging: the verbose messages
     are emitted solely for side "1", outside of predict mode, and when
     ``initialize`` is truthy, so that they appear once rather than on
     every call.

     Returns the selected loss function.
     """
     # One shared gate for both verbose messages. The first message used to
     # check only the side, so it was also logged in predict mode and on
     # every call, contrary to the stated "log once" intent.
     announce = side == "1" and not self.predict and initialize

     if announce:
         logger.verbose(
             "Loss function for the model is mean_absolute_error")
     loss_func = losses.mean_absolute_error
     if mask and self.config.get("penalized_mask_loss", False):
         loss_mask = mask[0]
         if announce:
             logger.verbose("Penalizing mask for Loss")
         loss_func = PenalizedLoss(loss_mask, loss_func)
     return loss_func
示例#5
0
    def loss_function(self, mask, side, initialize):
        """ Select the loss function for the face output.

        Uses ``DSSIMObjective`` when the "dssim_loss" config option is
        enabled, otherwise mean absolute error. When ``mask`` is truthy and
        "penalized_mask_loss" is enabled, the chosen loss is wrapped in
        ``PenalizedLoss`` together with ``mask[0]``.

        Verbose messages are emitted only for side "a", outside of predict
        mode, and when ``initialize`` is truthy, so they are logged once.

        Returns the selected loss function.
        """
        use_dssim = self.config.get("dssim_loss", False)
        # Single gate shared by every verbose message in this method
        announce = side == "a" and not self.predict and initialize

        if announce:
            logger.verbose("Using DSSIM Loss" if use_dssim
                           else "Using Mean Absolute Error Loss")
        chosen = DSSIMObjective() if use_dssim else losses.mean_absolute_error

        # Nothing more to do unless a mask is present and penalizing is on
        if not (mask and self.config.get("penalized_mask_loss", False)):
            return chosen

        first_mask = mask[0]
        if announce:
            logger.verbose("Penalizing mask for Loss")
        return PenalizedLoss(first_mask, chosen)
示例#6
0
 def get_loss_functions(self):
     """ Build the list of loss functions, one per entry in ``self.names``.

     Outputs whose name starts with "mask" use ``self.selected_mask_loss``.
     Other outputs use ``self.selected_loss``; when ``self.mask_input`` is
     set and the "penalized_mask_loss" config option is enabled, that loss
     is wrapped in ``PenalizedLoss``, with the mask scaled by the ratio of
     the output's size to the mask's size (index 1 of each shape).

     Returns the selected loss functions in the same order as
     ``self.names``.
     """
     collected = []
     for position, output_name in enumerate(self.names):
         if output_name.startswith("mask"):
             chosen = self.selected_mask_loss
         elif self.mask_input is None or not self.config.get("penalized_mask_loss", False):
             # No mask available, or penalizing disabled: plain face loss
             chosen = self.selected_loss
         else:
             out_dim = self.output_shapes[position][1]
             mask_dim = self.mask_shape[1]
             ratio = out_dim / mask_dim
             logger.debug("face_size: %s mask_size: %s, mask_scaling: %s",
                          out_dim, mask_dim, ratio)
             chosen = PenalizedLoss(self.mask_input, self.selected_loss,
                                    mask_scaling=ratio,
                                    preprocessing_func=self.mask_preprocessing_func)
         collected.append(chosen)
         logger.debug("%s: %s", output_name, chosen)
     logger.debug(collected)
     return collected