def __init__(self, K, scale_table=None, mean_table=None, weight_table=None,
             *args, scale_bound=0.11, tail_mass=1e-9, **kwargs):
    """Initialize a K-component conditional entropy model.

    Args:
        K: Stored on ``self.K`` (original author's note: the number of
            mixture components).
        scale_table: ``None`` or a list/tuple of strictly positive scales in
            non-decreasing order. When truthy it is converted with
            ``self._prepare_scale_table`` and registered as the
            ``scale_table`` buffer; otherwise that buffer is an empty tensor.
        mean_table: Accepted but not used by this initializer.
        weight_table: Accepted but not used by this initializer.
        *args, **kwargs: Forwarded unchanged to the parent ``__init__``.
        scale_bound: Used to build ``self.lower_bound_scale`` and registered
            as the ``scale_bound`` buffer (the buffer is ``None`` when this is
            ``None``). If ``None``, the first entry of the prepared scale
            table is used for the lower bound instead.
        tail_mass: Stored as a float on ``self.tail_mass``.

    Raises:
        ValueError: If ``scale_table`` is unsorted or contains a non-positive
            value, or if neither a positive ``scale_bound`` nor a non-empty
            ``scale_table`` is available to derive the lower bound.
    """
    super().__init__(*args, **kwargs)

    # Number of mixture components (original note: "ywz for mixture numbers:K").
    self.K = K

    if scale_table and (
        scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
    ):
        raise ValueError(f'Invalid scale_table "({scale_table})"')

    self.register_buffer(
        'scale_table',
        self._prepare_scale_table(scale_table) if scale_table else torch.Tensor())

    self.register_buffer(
        'scale_bound',
        torch.Tensor([float(scale_bound)]) if scale_bound is not None else None)

    self.tail_mass = float(tail_mass)

    if scale_bound is None and scale_table:
        # No explicit bound: fall back to the smallest (first) prepared scale.
        self.lower_bound_scale = LowerBound(self.scale_table[0])
    elif scale_bound is not None and scale_bound > 0:
        # Guard against None so an unusable configuration raises the intended
        # ValueError below rather than a TypeError from `None > 0`.
        self.lower_bound_scale = LowerBound(scale_bound)
    else:
        raise ValueError('Invalid parameters')
def __init__(self, scale_table, *args, scale_bound=0.11, tail_mass=1e-9, **kwargs):
    """Validate the scale table and set up the scale lower bound.

    Args:
        scale_table: ``None`` or a non-empty list/tuple of strictly positive
            scales in non-decreasing order. When given it is converted with
            ``self._prepare_scale_table`` and registered as the
            ``scale_table`` buffer; otherwise that buffer is an empty tensor.
        *args, **kwargs: Forwarded unchanged to the parent ``__init__``.
        scale_bound: Used to build ``self.lower_bound_scale`` and registered
            as the ``scale_bound`` buffer (the buffer is ``None`` when this is
            ``None``). If ``None``, the first entry of the prepared scale
            table is used for the lower bound instead.
        tail_mass: Stored as a float on ``self.tail_mass``.

    Raises:
        ValueError: If ``scale_table`` has the wrong type, is empty, is
            unsorted, or contains a non-positive value; or if neither a
            positive ``scale_bound`` nor a ``scale_table`` is available to
            derive the lower bound.
    """
    super().__init__(*args, **kwargs)

    if not isinstance(scale_table, (type(None), list, tuple)):
        raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"')

    if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1:
        raise ValueError(f'Invalid scale_table length "{len(scale_table)}"')

    if scale_table and (
        scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
    ):
        raise ValueError(f'Invalid scale_table "({scale_table})"')

    self.tail_mass = float(tail_mass)

    # Register the buffers before deriving the lower bound: the fallback
    # branch below reads ``self.scale_table``, which only exists once the
    # "scale_table" buffer has been registered.
    self.register_buffer(
        "scale_table",
        self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(),
    )

    self.register_buffer(
        "scale_bound",
        torch.Tensor([float(scale_bound)]) if scale_bound is not None else None,
    )

    if scale_bound is None and scale_table:
        # No explicit bound: fall back to the smallest (first) prepared scale.
        self.lower_bound_scale = LowerBound(self.scale_table[0])
    elif scale_bound is not None and scale_bound > 0:
        # Guard against None so an unusable configuration raises the intended
        # ValueError below rather than a TypeError from `None > 0`.
        self.lower_bound_scale = LowerBound(scale_bound)
    else:
        raise ValueError("Invalid parameters")
def test_lower_bound_grads(self):
    """Check that LowerBound passes gradient only where the input is >= bound.

    The upstream gradient handed to ``backward`` is the input itself
    (non-negative, drawn from ``torch.rand``). The expected input gradient is
    therefore the input masked to positions at or above the bound, and zero
    elsewhere.
    """
    inp = torch.rand(16, requires_grad=True)
    threshold = torch.rand(1)
    module = LowerBound(threshold)

    out = module(inp)
    out.backward(inp)

    assert inp.grad is not None
    expected = inp * (inp >= threshold)
    assert (inp.grad == expected).all()
def test_lower_bound_ok(self):
    """Check LowerBound(b)(x) == torch.max(x, b), eagerly and when scripted."""
    values = torch.rand(16)
    floor = torch.rand(1)
    module = LowerBound(floor)
    expected = torch.max(values, floor)

    # Eager-mode forward.
    assert (module(values) == expected).all()

    # TorchScript-compiled forward must agree with eager mode.
    compiled = torch.jit.script(module)
    assert (compiled(values) == expected).all()
def __init__(
    self, likelihood_bound=1e-9, entropy_coder=None, entropy_coder_precision=16
):
    """Set up the entropy coder, optional likelihood floor and CDF buffers.

    Args:
        likelihood_bound: When strictly positive, a ``LowerBound`` with this
            value is stored as ``self.likelihood_lower_bound`` and
            ``self.use_likelihood_bound`` is True; otherwise no bound module
            is created.
        entropy_coder: Name/spec passed to ``_EntropyCoder``; when ``None``,
            the value returned by ``default_entropy_coder()`` is used.
        entropy_coder_precision: Stored as an int on
            ``self.entropy_coder_precision``.
    """
    super().__init__()

    coder_spec = default_entropy_coder() if entropy_coder is None else entropy_coder
    self.entropy_coder = _EntropyCoder(coder_spec)
    self.entropy_coder_precision = int(entropy_coder_precision)

    self.use_likelihood_bound = likelihood_bound > 0
    if self.use_likelihood_bound:
        self.likelihood_lower_bound = LowerBound(likelihood_bound)

    # Empty integer buffers, to be filled on update().
    for buffer_name in ("_offset", "_quantized_cdf", "_cdf_length"):
        self.register_buffer(buffer_name, torch.IntTensor())