Example #1
    def test_SigmoidScale(self):
        # T.random.seed(1234)

        x = torch.randn([2, 3, 4])

        for pre_scale_bias in [None, 0., 1.5]:
            scale = SigmoidScale(**({
                'pre_scale_bias': pre_scale_bias
            } if pre_scale_bias is not None else {}))
            if pre_scale_bias is None:
                pre_scale_bias = 0.
            assert (f'pre_scale_bias={pre_scale_bias}' in repr(scale))
            # scale = T.jit_compile(scale)

            for pre_scale in [
                    torch.randn([4]),
                    torch.randn([3, 1]),
                    torch.randn([2, 1, 1]),
                    torch.randn([2, 3, 4])
            ]:
                expected_y = x * torch.sigmoid(pre_scale + pre_scale_bias)
                expected_log_det = broadcast_to(
                    log_sigmoid(pre_scale + pre_scale_bias), list(x.shape))
                check_scale(self, scale, x, pre_scale, expected_y,
                            expected_log_det)
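Read together with Example #5 below, where the base `forward` defers to `self._scale_and_log_scale(pre_scale, inverse, compute_log_det)`, the expectations in this test pin down the forward behaviour of `SigmoidScale`. The following is a minimal sketch of that hook; the free-function name is hypothetical and the `inverse=True` branch (reciprocal scale, negated log-scale) is an assumption the test does not exercise:

import torch
from torch import Tensor
from typing import Optional, Tuple


def sigmoid_scale_and_log_scale(pre_scale: Tensor,
                                pre_scale_bias: float = 0.,
                                inverse: bool = False,
                                compute_log_det: bool = True
                                ) -> Tuple[Tensor, Optional[Tensor]]:
    # forward: y = x * sigmoid(pre_scale + pre_scale_bias), with
    # log-scale = log_sigmoid(pre_scale + pre_scale_bias), as the test expects
    shifted = pre_scale + pre_scale_bias
    log_scale = torch.nn.functional.logsigmoid(shifted)
    if inverse:
        # assumed inverse: divide by the scale, negate the log-scale
        scale = torch.exp(-log_scale)
        log_scale = -log_scale
    else:
        scale = torch.sigmoid(shifted)
    return scale, (log_scale if compute_log_det else None)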
Example #2
    def forward(
            self,
            input: Tensor,
            input_log_det: Optional[Tensor] = None,
            inverse: bool = False,
            compute_log_det: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Transform `x` into `y` and compute the log-determinant of `f` at `x`
        (if `inverse` is False); or transform `y` into `x` and compute the
        log-determinant of `f^{-1}` at `y` (if `inverse` is True).

        Args:
            input: `x` (if `inverse` is False) or `y` (if `inverse` is True).
            input_log_det: The log-determinant of the previous layer.
                Will add the log-determinant of this layer to `input_log_det`,
                to obtain the output log-determinant.  If no previous layer,
                will start from zero log-det.
            inverse: See above.
            compute_log_det: Whether or not to compute the log-determinant.

        Returns:
            The transformed tensor, and the summed log-determinant of
            the previous flow layer and this layer.
        """
        if inverse:
            event_ndims = self.y_event_ndims
        else:
            event_ndims = self.x_event_ndims

        if input.dim() < event_ndims:
            raise ValueError(
                '`input` is required to be at least {}d, but the input shape '
                'is {}.'.format(event_ndims, list(input.shape)))

        input_shape = list(input.shape)
        log_det_shape = input_shape[:len(input_shape) - event_ndims]

        if input_log_det is not None:
            if list(input_log_det.shape) != log_det_shape:
                raise ValueError(
                    'The shape of `input_log_det` is not expected: '
                    'expected to be {}, but got {}.'.format(
                        log_det_shape, list(input_log_det.shape)))

        # compute the transformed output and log-det
        output, output_log_det = self._forward(input, input_log_det, inverse,
                                               compute_log_det)

        if output_log_det is not None:
            if output_log_det.dim() < len(log_det_shape):
                output_log_det = broadcast_to(output_log_det, log_det_shape)

            if list(output_log_det.shape) != log_det_shape:
                raise ValueError(
                    'The shape of `output_log_det` is not expected: '
                    'expected to be {}, but got {}.'.format(
                        log_det_shape, list(output_log_det.shape)))

        return output, output_log_det
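The `input_log_det` / `output_log_det` contract documented above is easiest to see with plain tensors. The following is a self-contained illustration of how log-determinants accumulate across stacked layers, using an exp-style scale; it only mirrors the documented behaviour and is not the library's own code:

import torch

x = torch.randn([2, 3, 4])          # batch shape [2, 3], event shape [4]
s1 = torch.randn([4])
s2 = torch.randn([4])

# layer 1: y1 = x * exp(s1); its log-det has the batch shape [2, 3]
y1 = x * torch.exp(s1)
log_det1 = s1.sum().expand(2, 3)

# layer 2 is given `input_log_det=log_det1` and adds its own contribution
y2 = y1 * torch.exp(s2)
log_det2 = log_det1 + s2.sum().expand(2, 3)

# inverse direction: applying the inverses recovers x, and the log-det of
# f^{-1} is the negation of the forward log-det
x_rec = y2 * torch.exp(-s2) * torch.exp(-s1)
assert torch.allclose(x_rec, x, atol=1e-4)
assert torch.allclose(log_det2, (s1 + s2).sum().expand(2, 3))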
Example #3
    def test_ExpScale(self):
        # T.random.seed(1234)

        x = torch.randn([2, 3, 4])
        scale = ExpScale()
        # scale = T.jit_compile(scale)

        for pre_scale in [
                torch.randn([4]),
                torch.randn([3, 1]),
                torch.randn([2, 1, 1]),
                torch.randn([2, 3, 4])
        ]:
            expected_y = x * torch.exp(pre_scale)
            expected_log_det = broadcast_to(pre_scale, list(x.shape))
            check_scale(self, scale, x, pre_scale, expected_y,
                        expected_log_det)
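Given that the base `forward` in Example #5 calls `_scale_and_log_scale(pre_scale, inverse, compute_log_det)`, the expectations above suggest the following minimal sketch of `ExpScale`'s core; the free-function form is hypothetical and the inverse branch is an assumption not covered by the test:

import torch
from torch import Tensor
from typing import Optional, Tuple


def exp_scale_and_log_scale(pre_scale: Tensor,
                            inverse: bool = False,
                            compute_log_det: bool = True
                            ) -> Tuple[Tensor, Optional[Tensor]]:
    # forward: scale = exp(pre_scale), log-scale = pre_scale (matches the test)
    # inverse (assumed): scale = exp(-pre_scale), log-scale = -pre_scale
    log_scale = -pre_scale if inverse else pre_scale
    return torch.exp(log_scale), (log_scale if compute_log_det else None)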
Example #4
    def test_LinearScale(self):
        # T.random.seed(1234)

        x = torch.randn([2, 3, 4])
        scale = LinearScale(epsilon=1e-5)
        assert ('epsilon=' in repr(scale))
        # scale = T.jit_compile(scale)

        for pre_scale in [
                torch.randn([4]),
                torch.randn([3, 1]),
                torch.randn([2, 1, 1]),
                torch.randn([2, 3, 4])
        ]:
            expected_y = x * pre_scale
            expected_log_det = broadcast_to(torch.log(torch.abs(pre_scale)),
                                            list(x.shape))
            check_scale(self, scale, x, pre_scale, expected_y,
                        expected_log_det)
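The `LinearScale` expectations are similar, except that the scale is applied directly and the log-det needs `log(|pre_scale|)`; the `epsilon` argument is presumably there to keep that logarithm away from zero, though exactly how the library applies it is not shown here. A hedged sketch under that assumption:

import torch
from torch import Tensor
from typing import Optional, Tuple


def linear_scale_and_log_scale(pre_scale: Tensor,
                               inverse: bool = False,
                               compute_log_det: bool = True,
                               epsilon: float = 1e-5
                               ) -> Tuple[Tensor, Optional[Tensor]]:
    # forward: y = x * pre_scale, log-scale = log(|pre_scale|), as the test expects
    # inverse (assumed): divide by pre_scale and negate the log-scale
    scale = 1. / pre_scale if inverse else pre_scale
    log_scale: Optional[Tensor] = None
    if compute_log_det:
        # assumption: epsilon clamps |pre_scale| before the log for stability
        log_scale = torch.log(torch.clamp(torch.abs(pre_scale), min=epsilon))
        if inverse:
            log_scale = -log_scale
    return scale, log_scale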
Example #5
    def forward(self,
                input: Tensor,
                pre_scale: Tensor,
                event_ndims: int,
                input_log_det: Optional[Tensor] = None,
                compute_log_det: bool = True,
                inverse: bool = False
                ) -> Tuple[Tensor, Optional[Tensor]]:
        # validate the argument
        if input.dim() < event_ndims:
            raise ValueError(
                '`rank(input) >= event_ndims` does not hold: the `input` shape '
                'is {}, while `event_ndims` is {}.'.
                    format(list(input.shape), event_ndims)
            )
        if pre_scale.dim() > input.dim():
            raise ValueError(
                '`rank(input) >= rank(pre_scale)` does not hold: the `input` '
                'shape is {}, while the shape of `pre_scale` is {}.'.
                    format(list(input.shape), list(pre_scale.shape))
            )

        input_shape = list(input.shape)
        event_ndims_start = len(input_shape) - event_ndims
        event_shape = input_shape[event_ndims_start:]
        log_det_shape = input_shape[:event_ndims_start]

        if input_log_det is not None:
            if list(input_log_det.shape) != log_det_shape:
                raise ValueError(
                    'The shape of `input_log_det` is not expected: '
                    'expected to be {}, but got {}.'.
                        format(log_det_shape, list(input_log_det.shape))
                )

        scale, log_scale = self._scale_and_log_scale(
            pre_scale, inverse, compute_log_det)
        output = input * scale  # note: this is an element-wise multiplication!

        if log_scale is not None:
            log_scale = broadcast_to(
                log_scale,
                broadcast_shape(list(log_scale.shape), event_shape)
            )

            # the last `event_ndims` dimensions must match the `event_shape`
            log_scale_shape = list(log_scale.shape)
            log_scale_event_shape = \
                log_scale_shape[len(log_scale_shape) - event_ndims:]
            if log_scale_event_shape != event_shape:
                raise ValueError(
                    'The shape of the final {}d of `log_scale` is not expected: '
                    'expected to be {}, but got {}.'.
                        format(event_ndims, event_shape, log_scale_event_shape)
                )

            # reduce the last `event_ndims` of log_scale
            log_scale = reduce_sum(log_scale, axis=list(range(-event_ndims, 0)))

            # now add to input_log_det, or broadcast `log_scale` to `log_det_shape`
            if input_log_det is not None:
                output_log_det = input_log_det + log_scale
                if list(output_log_det.shape) != log_det_shape:
                    raise ValueError(
                        'The shape of the computed `output_log_det` is not expected: '
                        'expected to be {}, but got {}.'.
                            format(log_det_shape, list(output_log_det.shape))
                    )
            else:
                output_log_det = broadcast_to(log_scale, log_det_shape)
        else:
            output_log_det = None

        return output, output_log_det
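The shape handling above (broadcast the log-scale to the event shape, reduce it over the last `event_ndims` axes, then add it to the incoming log-det or broadcast it to `log_det_shape`) can be traced with plain tensors. A self-contained sketch assuming an exp-style scale; it mirrors the code above but is not the library code itself:

import torch

x = torch.randn([2, 3, 4])        # input_shape = [2, 3, 4]
pre_scale = torch.randn([3, 1])   # broadcasts against the input
event_ndims = 2                   # event_shape = [3, 4], log_det_shape = [2]

output = x * torch.exp(pre_scale)

# broadcast the log-scale to the event shape, then sum over the event axes
log_scale = pre_scale.expand(3, 4)
log_det = log_scale.sum(dim=(-2, -1))   # 0-d tensor here
log_det = log_det.expand(2)             # broadcast to log_det_shape, since input_log_det is None

assert list(output.shape) == [2, 3, 4]
assert list(log_det.shape) == [2]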