Example #1
    def _forward(self, input: Tensor, input_log_det: Optional[Tensor],
                 inverse: bool,
                 compute_log_det: bool) -> Tuple[Tensor, Optional[Tensor]]:
        # obtain the weight
        weight, log_det = self.invertible_matrix(
            inverse=inverse, compute_log_det=compute_log_det)
        spatial_ndims = self.x_event_ndims - 1
        weight = weight.reshape(list(weight.shape) + [1] * spatial_ndims)

        # compute the output
        output, front_shape = flatten_to_ndims(input, spatial_ndims + 2)
        output = self._linear_transform(output, weight)
        output = unflatten_from_ndims(output, front_shape)

        # compute the log_det
        output_log_det = input_log_det
        if log_det is not None:
            for axis in list(range(-spatial_ndims, 0)):
                log_det = log_det * float(input.shape[axis])
            if input_log_det is not None:
                output_log_det = input_log_det + log_det
            else:
                output_log_det = log_det

        return output, output_log_det
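
The key step above is scaling the per-matrix log|det W| by every spatial size, so the total log-determinant accounts for the same matrix being applied at each spatial position. A minimal sketch of that identity in plain PyTorch (the sizes below are hypothetical):

import torch

# per-position log|det W| of a hypothetical 4x4 weight
weight = torch.randn(4, 4)
log_det = torch.slogdet(weight).logabsdet

# multiplying by each spatial size, as the loop above does, gives the
# log-det of applying the same matrix at all H * W positions
H, W = 6, 7
total_log_det = log_det * float(H) * float(W)
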
Example #2
    def _forward(self, input: Tensor, input_log_det: Optional[Tensor],
                 inverse: bool,
                 compute_log_det: bool) -> Tuple[Tensor, Optional[Tensor]]:

        w = self.w
        b = self.b
        _ = self.get_uhat()  # refreshes `self.u_hat`, which is used below
        input_shape = list(input.shape)
        if input_shape[self.axis] != self.num_features:
            raise ValueError(
                'the size of `input` along `axis` does not equal '
                '`num_features`; please check the input shape')
        if inverse:
            raise ValueError('`inverse` should never be True for the '
                             'planar normalizing flow')
        x_flatten, front_shape = flatten_to_ndims(input, 2)
        wxb = torch.matmul(x_flatten, w.T) + b
        tanh_wb = torch.tanh(wxb)
        out = x_flatten + self.u_hat * tanh_wb
        out = unflatten_from_ndims(out, front_shape=front_shape)

        output_log_det = input_log_det
        if compute_log_det:
            grad = 1. - tanh_wb**2
            phi = grad * w  # shape == [?, n_units]
            u_phi = torch.matmul(phi, self.u_hat.T)
            log_det = torch.log(torch.abs(1. + u_phi))  # shape == [?, 1]
            log_det = unflatten_from_ndims(log_det, front_shape)
            log_det = torch.squeeze(log_det)
            if output_log_det is None:
                output_log_det = log_det
            else:
                output_log_det += log_det

        return out, output_log_det
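
For reference, this is the planar flow of Rezende & Mohamed: y = x + u_hat * tanh(w^T x + b), with log|det J| = log|1 + u_hat^T phi(x)| where phi(x) = (1 - tanh^2(w^T x + b)) * w. A minimal sketch checking that closed form against autograd on a single 3-dimensional input (w, u, b below are hypothetical stand-ins, not the module's parameters):

import torch

w = torch.randn(3)
u = torch.randn(3)
b = torch.randn(())

def f(x):
    return x + u * torch.tanh(torch.dot(w, x) + b)

x = torch.randn(3)
jac = torch.autograd.functional.jacobian(f, x)  # shape == [3, 3]
auto_log_det = torch.slogdet(jac).logabsdet

phi = (1. - torch.tanh(torch.dot(w, x) + b) ** 2) * w
closed_form = torch.log(torch.abs(1. + torch.dot(u, phi)))
# auto_log_det and closed_form agree up to numerical error
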
Example #3
    def test_planar(self):
        input1 = torch.randn(12, 5)
        input2 = torch.randn(3, 4, 5)
        model1 = Planar(num_features=5)
        model2 = Planar(num_features=5, event_ndims=1)
        model3 = Planar(num_features=5, event_ndims=2)
        input_log_det1 = torch.randn(12)
        input_log_det2 = torch.randn(3, 4)

        x_flatten, front_shape = flatten_to_ndims(input1, 2)
        wxb = torch.matmul(x_flatten, model1.w.T) + model1.b
        tanh_wb = torch.tanh(wxb)
        out = x_flatten + model1.get_uhat() * tanh_wb
        expected_y = unflatten_from_ndims(out, front_shape=front_shape)
        grad = 1. - tanh_wb**2
        phi = grad * model1.w  # shape == [?, n_units]
        u_phi = torch.matmul(phi, model1.get_uhat().T)
        log_det = torch.log(torch.abs(1. + u_phi))  # shape == [?, 1]
        expected_log_det = unflatten_from_ndims(log_det, front_shape)
        expected_log_det = expected_log_det.squeeze()

        noninvert_flow_standard_check(self, model1, input1, expected_y,
                                      expected_log_det, input_log_det1)

        x_flatten, front_shape = flatten_to_ndims(input2, 2)
        wxb = torch.matmul(x_flatten, model2.w.T) + model2.b
        tanh_wb = torch.tanh(wxb)
        out = x_flatten + model2.get_uhat() * tanh_wb
        expected_y = unflatten_from_ndims(out, front_shape=front_shape)
        grad = 1. - tanh_wb**2
        phi = grad * model2.w  # shape == [?, n_units]
        u_phi = torch.matmul(phi, model2.get_uhat().T)
        log_det = torch.log(torch.abs(1. + u_phi))  # shape == [?, 1]
        expected_log_det = unflatten_from_ndims(log_det, front_shape)
        expected_log_det = expected_log_det.squeeze()

        noninvert_flow_standard_check(self, model2, input2, expected_y,
                                      expected_log_det, input_log_det2)
Example #4
def check_invertible_linear(
    ctx,
    spatial_ndims: int,
    invertible_linear_factory,
    linear_factory,
    strict: bool,
):
    for batch_shape in ([2], [2, 3]):
        num_features = 4
        spatial_shape = [5, 6, 7][:spatial_ndims]
        x = torch.randn(
            list(batch_shape) + [num_features] + list(spatial_shape))

        # construct the layer
        flow = invertible_linear_factory(num_features, strict=strict)
        assert f'num_features={num_features}' in repr(flow)

        # derive the expected answer
        weight, log_det = flow.invertible_matrix(inverse=False,
                                                 compute_log_det=True)
        linear_kwargs = {}
        if spatial_ndims > 0:
            linear_kwargs['kernel_size'] = 1
        linear = linear_factory(num_features,
                                num_features,
                                weight_init=torch.reshape(
                                    weight,
                                    list(weight.shape) + [1] * spatial_ndims),
                                use_bias=False,
                                **linear_kwargs)
        x_flatten, front_shape = flatten_to_ndims(x, spatial_ndims + 2)
        expected_y = unflatten_from_ndims(linear(x_flatten), front_shape)
        expected_log_det = reduce_sum(
            log_det.expand(spatial_shape)).expand(batch_shape)

        # check the invertible layer
        flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                            torch.randn(list(batch_shape)))
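
The expected_log_det construction mirrors Example #1: expanding the scalar log-det over the spatial shape and summing multiplies it by the number of spatial positions. A minimal sketch of that equivalence in plain torch (reduce_sum is assumed to behave like a full sum; the sizes are hypothetical):

import torch

log_det = torch.tensor(0.5)
expanded = log_det.expand([5, 6])  # broadcast the scalar over all positions
summed = expanded.sum()            # 0.5 * 5 * 6 = 15.0
looped = log_det * 5. * 6.         # the same value as Example #1's loop
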
Example #5
    def _forward(self, input: Tensor, weight: Tensor,
                 bias: Optional[Tensor]) -> Tensor:
        # flatten the leading dims, apply the dense layer, then restore them
        output, front_shape = flatten_to_ndims(input, 2)
        output = torch.nn.functional.linear(output, weight, bias)
        output = unflatten_from_ndims(output, front_shape)
        return output
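
All five examples rely on the same flatten/unflatten pair. A minimal sketch of what these helpers appear to do, inferred from their use above (the real library implementations may handle more edge cases):

import torch
from typing import List, Tuple

def flatten_to_ndims(t: torch.Tensor, ndims: int) -> Tuple[torch.Tensor, List[int]]:
    # merge the leading dims so the result has exactly `ndims` dims;
    # e.g. shape [3, 4, 5] with ndims=2 becomes [12, 5], front_shape [3, 4]
    split = t.dim() - ndims + 1
    front_shape = list(t.shape[:split])
    return t.reshape([-1] + list(t.shape[split:])), front_shape

def unflatten_from_ndims(t: torch.Tensor, front_shape: List[int]) -> torch.Tensor:
    # restore the previously merged leading dims
    return t.reshape(front_shape + list(t.shape[1:]))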