def _log_prob_shape(dist, x_size=torch.Size()):
    """Compute the shape left after broadcasting and dropping event dims.

    The full shape of ``dist`` (``dist.shape()``) is broadcast against
    ``x_size`` via ``broadcast_shape(..., strict=True)``; the trailing
    ``len(dist.event_shape)`` dimensions are then removed from the result.

    NOTE(review): judging by the name, the result is presumably the shape
    expected of ``dist.log_prob(x)`` for a value ``x`` of size ``x_size`` --
    confirm against callers.

    Args:
        dist: Object exposing an ``event_shape`` attribute and a ``shape()``
            method.
        x_size: Size to broadcast against ``dist.shape()``. Defaults to an
            empty ``torch.Size``.

    Returns:
        The broadcast shape with its trailing event dimensions stripped.
    """
    n_event = len(dist.event_shape)
    full_shape = broadcast_shape(dist.shape(), x_size, strict=True)
    # Guard: slicing with ``[:-0]`` would yield an empty shape, so a
    # distribution without event dims keeps the broadcast shape untouched.
    if n_event <= 0:
        return full_shape
    return full_shape[:-n_event]