def _forward(self,
             input: Tensor,
             input_log_det: Optional[Tensor],
             inverse: bool,
             compute_log_det: bool) -> Tuple[Tensor, Optional[Tensor]]:
    w = self.w
    b = self.b
    _ = self.get_uhat()  # ensure `self.u_hat` is up to date before use
    input_shape = list(input.shape)
    if input_shape[self.axis] != self.num_features:
        raise ValueError(
            'The size of `input` at `self.axis` does not equal '
            '`num_features`, please check the input shape!')
    if inverse:
        raise ValueError('`inverse` must never be `True` for a planar '
                         'normalizing flow: it has no analytic inverse.')

    # y = x + u_hat * tanh(x @ w^T + b)
    x_flatten, front_shape = flatten_to_ndims(input, 2)
    wxb = torch.matmul(x_flatten, w.T) + b
    tanh_wb = torch.tanh(wxb)
    out = x_flatten + self.u_hat * tanh_wb
    out = unflatten_from_ndims(out, front_shape=front_shape)

    output_log_det = input_log_det
    if compute_log_det:
        # by the matrix determinant lemma:
        #   det(I + u_hat @ phi^T) = 1 + u_hat^T @ phi,
        # where phi = (1 - tanh^2(x @ w^T + b)) * w
        grad = 1. - tanh_wb ** 2
        phi = grad * w  # shape == [?, n_units]
        u_phi = torch.matmul(phi, self.u_hat.T)
        log_det = torch.log(torch.abs(1. + u_phi))  # shape == [?, 1]
        log_det = unflatten_from_ndims(log_det, front_shape)
        log_det = torch.squeeze(log_det)
        if output_log_det is None:
            output_log_det = log_det
        else:
            # avoid `+=`, which would mutate the caller's `input_log_det`
            # tensor in place
            output_log_det = output_log_det + log_det

    return out, output_log_det
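# `flatten_to_ndims` / `unflatten_from_ndims` are used throughout this
# section but not defined here.  A minimal sketch of their assumed behavior,
# inferred from the call sites above (the real helpers may differ; assumes
# `import torch`, `from torch import Tensor`, `from typing import List, Tuple`):
def flatten_to_ndims(x: Tensor, ndims: int) -> Tuple[Tensor, List[int]]:
    # collapse the leading dims into one, so that `x` ends up with exactly
    # `ndims` dims; return the collapsed front shape for later restoration
    front_shape = list(x.shape[:x.dim() - (ndims - 1)])
    back_shape = list(x.shape[x.dim() - (ndims - 1):])
    return x.reshape([-1] + back_shape), front_shape


def unflatten_from_ndims(x: Tensor, front_shape: List[int]) -> Tensor:
    # restore the leading dims collapsed by `flatten_to_ndims`
    return x.reshape(list(front_shape) + list(x.shape[1:]))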
def _forward(self,
             input: Tensor,
             input_log_det: Optional[Tensor],
             inverse: bool,
             compute_log_det: bool) -> Tuple[Tensor, Optional[Tensor]]:
    # obtain the weight
    weight, log_det = self.invertible_matrix(
        inverse=inverse, compute_log_det=compute_log_det)
    spatial_ndims = self.x_event_ndims - 1
    weight = weight.reshape(list(weight.shape) + [1] * spatial_ndims)

    # compute the output
    output, front_shape = flatten_to_ndims(input, spatial_ndims + 2)
    output = self._linear_transform(output, weight)
    output = unflatten_from_ndims(output, front_shape)

    # compute the log_det: the same matrix acts at every spatial position,
    # so its log-det is scaled by the number of positions
    output_log_det = input_log_det
    if log_det is not None:
        for axis in range(-spatial_ndims, 0):
            log_det = log_det * float(input.shape[axis])
        if input_log_det is not None:
            output_log_det = input_log_det + log_det
        else:
            output_log_det = log_det

    return output, output_log_det
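# `_linear_transform` is defined elsewhere in the class hierarchy.  A
# plausible sketch of what it might do (an assumption, not the library's
# actual method): once the weight has been reshaped to
# [n, n] + [1] * spatial_ndims, a dense matmul and a 1x1 convolution apply
# the same n x n matrix at every spatial position, which is exactly why
# `_forward` above multiplies `log_det` by each spatial size.
def _linear_transform(self, input: Tensor, weight: Tensor) -> Tensor:
    # `input` has shape [batch, n] + spatial_shape after flattening
    if input.dim() == 2:
        return torch.nn.functional.linear(input, weight)
    conv_fn = [torch.nn.functional.conv1d,
               torch.nn.functional.conv2d,
               torch.nn.functional.conv3d][input.dim() - 3]
    return conv_fn(input, weight)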
def test_planar(self):
    input1 = torch.randn(12, 5)
    input2 = torch.randn(3, 4, 5)
    model1 = Planar(num_features=5)
    model2 = Planar(num_features=5, event_ndims=1)
    model3 = Planar(num_features=5, event_ndims=2)  # constructed but not exercised below
    input_log_det1 = torch.randn(12)
    input_log_det2 = torch.randn(3, 4)

    # check `model1` against a manual planar transform of `input1`
    x_flatten, front_shape = flatten_to_ndims(input1, 2)
    wxb = torch.matmul(x_flatten, model1.w.T) + model1.b
    tanh_wb = torch.tanh(wxb)
    out = x_flatten + model1.get_uhat() * tanh_wb
    expected_y = unflatten_from_ndims(out, front_shape=front_shape)
    grad = 1. - tanh_wb ** 2
    phi = grad * model1.w  # shape == [?, n_units]
    u_phi = torch.matmul(phi, model1.get_uhat().T)
    log_det = torch.log(torch.abs(1. + u_phi))  # shape == [?, 1]
    expected_log_det = unflatten_from_ndims(log_det, front_shape)
    expected_log_det = expected_log_det.squeeze()
    noninvert_flow_standard_check(self, model1, input1, expected_y,
                                  expected_log_det, input_log_det1)

    # check `model2` the same way on an input with two batch dims
    x_flatten, front_shape = flatten_to_ndims(input2, 2)
    wxb = torch.matmul(x_flatten, model2.w.T) + model2.b
    tanh_wb = torch.tanh(wxb)
    out = x_flatten + model2.get_uhat() * tanh_wb
    expected_y = unflatten_from_ndims(out, front_shape=front_shape)
    grad = 1. - tanh_wb ** 2
    phi = grad * model2.w  # shape == [?, n_units]
    u_phi = torch.matmul(phi, model2.get_uhat().T)
    log_det = torch.log(torch.abs(1. + u_phi))  # shape == [?, 1]
    expected_log_det = unflatten_from_ndims(log_det, front_shape)
    expected_log_det = expected_log_det.squeeze()
    noninvert_flow_standard_check(self, model2, input2, expected_y,
                                  expected_log_det, input_log_det2)
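# A quick usage sketch (not part of the test suite above): calling the planar
# flow's `_forward` directly, with the signature shown earlier in this
# section, and checking the resulting shapes.
flow = Planar(num_features=5)
x = torch.randn(12, 5)
y, log_det = flow._forward(x, input_log_det=None, inverse=False,
                           compute_log_det=True)
assert y.shape == (12, 5)
assert log_det.shape == (12,)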
def check_invertible_linear(ctx,
                            spatial_ndims: int,
                            invertible_linear_factory,
                            linear_factory,
                            strict: bool):
    for batch_shape in ([2], [2, 3]):
        num_features = 4
        spatial_shape = [5, 6, 7][:spatial_ndims]
        x = torch.randn(
            list(batch_shape) + [num_features] + list(spatial_shape))

        # construct the layer
        flow = invertible_linear_factory(num_features, strict=strict)
        assert f'num_features={num_features}' in repr(flow)

        # derive the expected answer
        weight, log_det = flow.invertible_matrix(
            inverse=False, compute_log_det=True)
        linear_kwargs = {}
        if spatial_ndims > 0:
            linear_kwargs['kernel_size'] = 1
        linear = linear_factory(
            num_features,
            num_features,
            weight_init=torch.reshape(
                weight, list(weight.shape) + [1] * spatial_ndims),
            use_bias=False,
            **linear_kwargs,
        )
        x_flatten, front_shape = flatten_to_ndims(x, spatial_ndims + 2)
        expected_y = unflatten_from_ndims(linear(x_flatten), front_shape)
        expected_log_det = reduce_sum(
            log_det.expand(spatial_shape)).expand(batch_shape)

        # check the invertible layer
        flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                            torch.randn(list(batch_shape)))
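# A worked instance of the expected log-det computed above (a sketch,
# assuming `reduce_sum` behaves like `torch.sum`): the invertible matrix acts
# independently at each spatial position, so its log-det is summed over all
# of them, then broadcast over the batch.
log_det = torch.tensor(0.3)
spatial_shape, batch_shape = [5, 6], [2]
expected = torch.sum(log_det.expand(spatial_shape)).expand(batch_shape)
# 5 * 6 == 30 positions, so the total is 0.3 * 30 == 9.0 per batch element
assert torch.allclose(expected, torch.tensor([9.0, 9.0]))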
def _forward(self,
             input: Tensor,
             weight: Tensor,
             bias: Optional[Tensor]) -> Tensor:
    # collapse the leading batch dims, apply the dense layer, then restore
    output, front_shape = flatten_to_ndims(input, 2)
    output = torch.nn.functional.linear(output, weight, bias)
    output = unflatten_from_ndims(output, front_shape)
    return output
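# Example (a sketch): the flatten/unflatten pair above lets one weight
# matrix serve inputs with any number of leading batch dims.
x = torch.randn(3, 4, 5)                      # two batch dims, 5 features
weight = torch.randn(7, 5)                    # maps 5 features -> 7
flat, front_shape = flatten_to_ndims(x, 2)    # shape [12, 5]
y = torch.nn.functional.linear(flat, weight)  # shape [12, 7]
y = unflatten_from_ndims(y, front_shape)      # shape [3, 4, 7]
assert y.shape == (3, 4, 7)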