Пример #1
0
 def get_value(self, obj, distribution_mode=None, device=None, **kwargs):
     """Return the total-variation (TV) regularization term for ``obj``.

     Entries 0 and 1 of the last axis of ``obj`` are read as two
     components. How they are used depends on ``self.unknown_type``:

     * ``'delta_beta'``: TV is applied to each component directly.
     * ``'real_imag'``: the components are treated as real part ``r`` and
       imaginary part ``i``. TV is applied to the squared modulus
       ``r ** 2 + i ** 2`` and to the phase ``arctan2(i, r)``.

     The returned value is ``self.gamma * (TV(a) + TV(b))``, where ``a`` and
     ``b`` are the two quantities above.

     :param obj: array/tensor whose last axis has at least 2 entries.
     :param distribution_mode: if not None, ``axis_offset=1`` is passed to
         ``total_variation_3d``; otherwise ``axis_offset=0``. (Presumably this
         skips a leading distribution axis -- confirm against
         ``total_variation_3d``.)
     :param device: forwarded to ``w.create_variable`` for the accumulator.
     :param kwargs: accepted for interface compatibility; unused.
     :raises ValueError: if ``self.unknown_type`` is not recognized.
     """
     # ``obj[..., k]`` selects entry k of the last axis. Avoid indexing with a
     # plain list of slices: NumPy >= 1.23 no longer treats it as a tuple.
     if self.unknown_type == 'delta_beta':
         a = obj[..., 0]
         b = obj[..., 1]
     elif self.unknown_type == 'real_imag':
         r = obj[..., 0]
         i = obj[..., 1]
         # Squared modulus (not the modulus itself).
         a = r ** 2 + i ** 2
         b = w.arctan2(i, r)
     else:
         raise ValueError('Invalid value for unknown_type.')

     axis_offset = 0 if distribution_mode is None else 1
     reg = w.create_variable(0., device=device)
     reg = reg + self.gamma * total_variation_3d(a, axis_offset=axis_offset)
     reg = reg + self.gamma * total_variation_3d(b, axis_offset=axis_offset)
     return reg
Пример #2
0
 def get_value(self, obj, device=None, **kwargs):
     """Return the L1-type regularization term for ``obj``.

     Entries 0 and 1 of the last axis of ``obj`` are read as two
     components. Each weighted term is skipped when its weight
     (``self.alpha_d`` / ``self.alpha_b``) is ``None`` or 0.

     * ``'delta_beta'``: ``alpha_d * mean(|obj[..., 0]|)`` plus
       ``alpha_b * mean(|obj[..., 1]|)``.
     * ``'real_imag'``: with ``r = obj[..., 0]``, ``i = obj[..., 1]`` and
       ``m = sqrt(r**2 + i**2)``, the result is
       ``alpha_d * mean(|m - mean(m)|)`` plus
       ``alpha_b * mean(|arctan2(i, r)|)``.

     :param obj: array/tensor whose last axis has at least 2 entries.
     :param device: forwarded to ``w.create_variable`` for the accumulator.
     :param kwargs: accepted for interface compatibility; unused.
     :raises ValueError: if ``self.unknown_type`` is not recognized.
     """
     reg = w.create_variable(0., device=device)
     # ``obj[..., k]`` selects entry k of the last axis. Avoid indexing with a
     # plain list of slices: NumPy >= 1.23 no longer treats it as a tuple.
     if self.unknown_type == 'delta_beta':
         if self.alpha_d not in [None, 0]:
             reg = reg + self.alpha_d * w.mean(w.abs(obj[..., 0]))
         if self.alpha_b not in [None, 0]:
             reg = reg + self.alpha_b * w.mean(w.abs(obj[..., 1]))
     elif self.unknown_type == 'real_imag':
         r = obj[..., 0]
         i = obj[..., 1]
         if self.alpha_d not in [None, 0]:
             # Penalize deviation of the modulus from its mean value.
             om = w.sqrt(r ** 2 + i ** 2)
             reg = reg + self.alpha_d * w.mean(w.abs(om - w.mean(om)))
         if self.alpha_b not in [None, 0]:
             reg = reg + self.alpha_b * w.mean(w.abs(w.arctan2(i, r)))
     else:
         raise ValueError('Invalid value for unknown_type.')
     return reg
Пример #3
0
    def get_value(self, obj, distribution_mode=None, device=None, **kwargs):
        """Return ``self.gamma * (w.pcc(a) + w.pcc(b))`` for two components of ``obj``.

        Entries 0 and 1 of the last axis of ``obj`` are read as two
        components:

        * ``'delta_beta'``: ``a`` and ``b`` are the components themselves.
        * ``'real_imag'``: the components are treated as real part ``r`` and
          imaginary part ``i``. ``a`` is the modulus ``sqrt(r**2 + i**2)``
          and ``b`` is the phase ``arctan2(i, r)``.

        The exact meaning of ``w.pcc`` is defined in the wrappers module and
        is not visible here (presumably a correlation-based measure --
        confirm).

        :param obj: array/tensor whose last axis has at least 2 entries.
        :param distribution_mode: accepted for interface compatibility;
            unused here.
        :param device: forwarded to ``w.create_variable`` for the accumulator.
        :param kwargs: accepted for interface compatibility; unused.
        :raises ValueError: if ``self.unknown_type`` is not recognized.
        """
        # ``obj[..., k]`` selects entry k of the last axis. Avoid indexing with
        # a plain list of slices: NumPy >= 1.23 no longer treats it as a tuple.
        if self.unknown_type == 'delta_beta':
            o1 = obj[..., 0]
            o2 = obj[..., 1]
        elif self.unknown_type == 'real_imag':
            r = obj[..., 0]
            i = obj[..., 1]
            o1 = w.sqrt(r ** 2 + i ** 2)
            o2 = w.arctan2(i, r)
        else:
            raise ValueError('Invalid value for unknown_type.')

        reg = w.create_variable(0., device=device)
        reg = reg + self.gamma * w.pcc(o1)
        reg = reg + self.gamma * w.pcc(o2)
        return reg