def test_SigmoidScale(self):
    # T.random.seed(1234)
    x = torch.randn([2, 3, 4])

    for pre_scale_bias in [None, 0., 1.5]:
        scale = SigmoidScale(**(
            {'pre_scale_bias': pre_scale_bias}
            if pre_scale_bias is not None else {}
        ))
        if pre_scale_bias is None:
            pre_scale_bias = 0.
        assert f'pre_scale_bias={pre_scale_bias}' in repr(scale)
        # scale = T.jit_compile(scale)

        for pre_scale in [torch.randn([4]),
                          torch.randn([3, 1]),
                          torch.randn([2, 1, 1]),
                          torch.randn([2, 3, 4])]:
            expected_y = x * torch.sigmoid(pre_scale + pre_scale_bias)
            expected_log_det = broadcast_to(
                log_sigmoid(pre_scale + pre_scale_bias), list(x.shape))
            check_scale(self, scale, x, pre_scale,
                        expected_y, expected_log_det)
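
# A minimal sketch (an assumption, not the library's actual implementation)
# of the behaviour the test above checks for `SigmoidScale`: the scale is
# sigmoid(pre_scale + pre_scale_bias) and the per-element log-scale is
# log_sigmoid(pre_scale + pre_scale_bias).  The function name below is
# hypothetical; the inverse branch is likewise assumed.
import torch
from typing import Optional, Tuple


def _sigmoid_scale_and_log_scale(
        pre_scale: torch.Tensor,
        pre_scale_bias: float = 0.,
        inverse: bool = False,
        compute_log_scale: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    shifted = pre_scale + pre_scale_bias
    scale = torch.sigmoid(shifted)
    if inverse:
        scale = 1. / scale  # the inverse transform divides by the scale
    log_scale = None
    if compute_log_scale:
        log_scale = torch.nn.functional.logsigmoid(shifted)
        if inverse:
            log_scale = -log_scale  # log-det of the inverse flips sign
    return scale, log_scale
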
def forward(self,
            input: Tensor,
            input_log_det: Optional[Tensor] = None,
            inverse: bool = False,
            compute_log_det: bool = True
            ) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Transform `x` into `y` and compute the log-determinant of `f` at `x`
    (if `inverse` is False); or transform `y` into `x` and compute the
    log-determinant of `f^{-1}` at `y` (if `inverse` is True).

    Args:
        input: `x` (if `inverse` is False) or `y` (if `inverse` is True).
        input_log_det: The log-determinant of the previous layer.
            The log-determinant of this layer will be added to
            `input_log_det` to obtain the output log-determinant.
            If there is no previous layer, the log-determinant starts
            from zero.
        inverse: See above.
        compute_log_det: Whether or not to compute the log-determinant.

    Returns:
        The transformed tensor, and the summed log-determinant of the
        previous flow layer and this layer.
    """
    if inverse:
        event_ndims = self.y_event_ndims
    else:
        event_ndims = self.x_event_ndims

    if input.dim() < event_ndims:
        raise ValueError(
            '`input` is required to be at least {}d, but the input shape '
            'is {}.'.format(event_ndims, list(input.shape)))

    input_shape = list(input.shape)
    log_det_shape = input_shape[:len(input_shape) - event_ndims]

    if input_log_det is not None:
        if list(input_log_det.shape) != log_det_shape:
            raise ValueError(
                'The shape of `input_log_det` is not as expected: '
                'expected to be {}, but got {}.'.format(
                    log_det_shape, list(input_log_det.shape)))

    # compute the transformed output and log-det
    output, output_log_det = self._forward(
        input, input_log_det, inverse, compute_log_det)

    if output_log_det is not None:
        if output_log_det.dim() < len(log_det_shape):
            output_log_det = broadcast_to(output_log_det, log_det_shape)

        if list(output_log_det.shape) != log_det_shape:
            raise ValueError(
                'The shape of `output_log_det` is not as expected: '
                'expected to be {}, but got {}.'.format(
                    log_det_shape, list(output_log_det.shape)))

    return output, output_log_det
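
# Shape-contract sketch for the `forward` above: the log-determinant is
# reduced over the trailing `event_ndims` event dimensions, so its shape is
# the input shape without the last `event_ndims` entries.  Plain-torch
# illustration with assumed values (event_ndims = 1 is hypothetical):
import torch

x = torch.randn([2, 3, 4])
event_ndims = 1                                    # e.g. x_event_ndims == 1
input_shape = list(x.shape)                        # [2, 3, 4]
log_det_shape = input_shape[:len(input_shape) - event_ndims]
assert log_det_shape == [2, 3]                     # per-sample log-det shape
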
def test_ExpScale(self):
    # T.random.seed(1234)
    x = torch.randn([2, 3, 4])
    scale = ExpScale()
    # scale = T.jit_compile(scale)

    for pre_scale in [torch.randn([4]),
                      torch.randn([3, 1]),
                      torch.randn([2, 1, 1]),
                      torch.randn([2, 3, 4])]:
        expected_y = x * torch.exp(pre_scale)
        expected_log_det = broadcast_to(pre_scale, list(x.shape))
        check_scale(self, scale, x, pre_scale,
                    expected_y, expected_log_det)
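
# A minimal sketch (an assumption, not the library's code) of what
# `test_ExpScale` verifies: the scale is exp(pre_scale) and the per-element
# log-scale is simply `pre_scale`; the inverse uses exp(-pre_scale).
import torch
from typing import Optional, Tuple


def _exp_scale_and_log_scale(
        pre_scale: torch.Tensor,
        inverse: bool = False,
        compute_log_scale: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    signed = -pre_scale if inverse else pre_scale
    scale = torch.exp(signed)
    log_scale = signed if compute_log_scale else None
    return scale, log_scale
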
def test_LinearScale(self):
    # T.random.seed(1234)
    x = torch.randn([2, 3, 4])
    scale = LinearScale(epsilon=1e-5)
    assert 'epsilon=' in repr(scale)
    # scale = T.jit_compile(scale)

    for pre_scale in [torch.randn([4]),
                      torch.randn([3, 1]),
                      torch.randn([2, 1, 1]),
                      torch.randn([2, 3, 4])]:
        expected_y = x * pre_scale
        expected_log_det = broadcast_to(
            torch.log(torch.abs(pre_scale)), list(x.shape))
        check_scale(self, scale, x, pre_scale,
                    expected_y, expected_log_det)
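
# A minimal sketch (assumption) matching `test_LinearScale`: the scale is
# `pre_scale` itself and the per-element log-scale is log(|pre_scale|).
# The `epsilon` is assumed to guard the logarithm against zero scales; the
# real class may handle it differently.
import torch
from typing import Optional, Tuple


def _linear_scale_and_log_scale(
        pre_scale: torch.Tensor,
        inverse: bool = False,
        compute_log_scale: bool = True,
        epsilon: float = 1e-5
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    scale = 1. / pre_scale if inverse else pre_scale
    log_scale = None
    if compute_log_scale:
        log_abs = torch.log(torch.abs(pre_scale).clamp_min(epsilon))
        log_scale = -log_abs if inverse else log_abs
    return scale, log_scale
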
def forward(self,
            input: Tensor,
            pre_scale: Tensor,
            event_ndims: int,
            input_log_det: Optional[Tensor] = None,
            compute_log_det: bool = True,
            inverse: bool = False
            ) -> Tuple[Tensor, Optional[Tensor]]:
    # validate the arguments
    if input.dim() < event_ndims:
        raise ValueError(
            '`rank(input) >= event_ndims` does not hold: the `input` shape '
            'is {}, while `event_ndims` is {}.'.format(
                list(input.shape), event_ndims))

    if pre_scale.dim() > input.dim():
        raise ValueError(
            '`rank(input) >= rank(pre_scale)` does not hold: the `input` '
            'shape is {}, while the shape of `pre_scale` is {}.'.format(
                list(input.shape), list(pre_scale.shape)))

    input_shape = list(input.shape)
    event_ndims_start = len(input_shape) - event_ndims
    event_shape = input_shape[event_ndims_start:]
    log_det_shape = input_shape[:event_ndims_start]

    if input_log_det is not None:
        if list(input_log_det.shape) != log_det_shape:
            raise ValueError(
                'The shape of `input_log_det` is not as expected: '
                'expected to be {}, but got {}.'.format(
                    log_det_shape, list(input_log_det.shape)))

    scale, log_scale = self._scale_and_log_scale(
        pre_scale, inverse, compute_log_det)
    output = input * scale  # note: element-wise multiplication!

    if log_scale is not None:
        log_scale = broadcast_to(
            log_scale,
            broadcast_shape(list(log_scale.shape), event_shape))

        # the last `event_ndims` dimensions must match the `event_shape`
        log_scale_shape = list(log_scale.shape)
        log_scale_event_shape = \
            log_scale_shape[len(log_scale_shape) - event_ndims:]
        if log_scale_event_shape != event_shape:
            raise ValueError(
                'The shape of the last {} dimensions of `log_scale` is not '
                'as expected: expected to be {}, but got {}.'.format(
                    event_ndims, event_shape, log_scale_event_shape))

        # reduce the last `event_ndims` dimensions of `log_scale`
        log_scale = reduce_sum(log_scale, axis=list(range(-event_ndims, 0)))

        # add to `input_log_det`, or broadcast `log_scale` to `log_det_shape`
        if input_log_det is not None:
            output_log_det = input_log_det + log_scale
            if list(output_log_det.shape) != log_det_shape:
                raise ValueError(
                    'The shape of the computed `output_log_det` is not as '
                    'expected: expected to be {}, but got {}.'.format(
                        log_det_shape, list(output_log_det.shape)))
        else:
            output_log_det = broadcast_to(log_scale, log_det_shape)
    else:
        output_log_det = None

    return output, output_log_det
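
# End-to-end sketch of the base `forward` above in plain torch, using the
# exponential-scale math from the tests (all names and values here are
# assumptions): with input [2, 3, 4], event_ndims=1 and a pre_scale of
# shape [4], the output keeps shape [2, 3, 4] while the log-det is reduced
# over the event dimensions down to shape [2, 3].
import torch

x = torch.randn([2, 3, 4])
pre_scale = torch.randn([4])
event_ndims = 1

scale = torch.exp(pre_scale)                 # exp-scale example
log_scale = pre_scale
output = x * scale                           # element-wise scaling

# broadcast the per-element log-scale, then reduce the event dimensions
log_scale = log_scale.expand(x.shape)
output_log_det = log_scale.sum(dim=list(range(-event_ndims, 0)))
assert list(output.shape) == [2, 3, 4]
assert list(output_log_det.shape) == [2, 3]
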