Functional
Functional interface for hyperbolic neural network operations.
This module provides stateless functions for hyperbolic computations, following PyTorch’s functional API pattern (similar to torch.nn.functional).
Functions
- compute_hyperbolic_mlr_logits
Compute logits for hyperbolic multinomial logistic regression.
- compute_hyperbolic_mlr_logits(x, weights, class_points, manifold)
Compute logits for hyperbolic multinomial logistic regression (MLR).
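This function implements the hyperbolic generalization of softmax/MLR on the Poincaré ball. For each class, the logit is proportional to the signed hyperbolic distance from the input point to a geodesic hyperplane that passes through the learned class representative and is oriented by the class weight vector.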
- Parameters:
x (torch.Tensor) – Input points on the Poincaré ball. Shape (batch_size, dim).
weights (torch.Tensor) – Weight vectors (a-values) for each class, scaled by conformal factor. Shape (n_classes, dim).
class_points (torch.Tensor) – Class representatives (p-values) on the Poincaré ball. Shape (n_classes, dim).
manifold (MobiusManifold) – The hyperbolic manifold instance. Currently only PoincareBall is supported.
- Returns:
Logits for each input point and class. Shape (batch_size, n_classes). Can be passed to standard softmax for classification probabilities.
- Return type:
torch.Tensor
- Raises:
NotImplementedError – If manifold is not an instance of PoincareBall.
Notes
The hyperbolic MLR generalizes logistic regression to hyperbolic space. For each class k with representative \(p_k\) and weights \(a_k\), the logit for an input point \(x\) is:
\[\text{logit}_k(x) = \frac{\lambda_{p_k}^c \|a_k\|}{\sqrt{c}} \sinh^{-1}\left(\frac{2\sqrt{c} \langle a_k, -p_k \oplus_c x \rangle} {(1 - c\|-p_k \oplus_c x\|^2)\|a_k\|}\right)\]
where:
- \(\lambda_{p_k}^c = \frac{2}{1 - c\|p_k\|^2}\) is the conformal factor
- \(\oplus_c\) denotes Möbius addition
- \(\sinh^{-1}\) is the inverse hyperbolic sine (arcsinh)
The formulation ensures that decision boundaries are geodesic hyperplanes in hyperbolic space.
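To make the formula concrete, the following is a minimal sketch that evaluates it directly in plain PyTorch. It is not this module's implementation: the helper names `mobius_add` and `mlr_logits_sketch` are hypothetical, the curvature is passed as a plain float `c` instead of a manifold object, and the small clamps are an assumed guard against division by zero.

```python
import torch


def mobius_add(u: torch.Tensor, v: torch.Tensor, c: float) -> torch.Tensor:
    """Mobius addition u (+)_c v on the Poincare ball; broadcasts over leading dims."""
    uv = (u * v).sum(dim=-1, keepdim=True)
    u2 = (u * u).sum(dim=-1, keepdim=True)
    v2 = (v * v).sum(dim=-1, keepdim=True)
    num = (1 + 2 * c * uv + c * v2) * u + (1 - c * u2) * v
    den = 1 + 2 * c * uv + c**2 * u2 * v2
    return num / den.clamp_min(1e-15)


def mlr_logits_sketch(x, weights, class_points, c=1.0):
    """Hypothetical helper: evaluate the logit formula term by term.

    x: (batch_size, dim), weights: (n_classes, dim), class_points: (n_classes, dim).
    Returns a tensor of shape (batch_size, n_classes).
    """
    sqrt_c = c ** 0.5
    # z[b, k] = -p_k (+)_c x_b, shape (batch_size, n_classes, dim)
    z = mobius_add(-class_points.unsqueeze(0), x.unsqueeze(1), c)
    a_norm = weights.norm(dim=-1).clamp_min(1e-15)            # ||a_k||
    lam = 2.0 / (1.0 - c * class_points.pow(2).sum(dim=-1))   # conformal factor at p_k
    inner = (z * weights.unsqueeze(0)).sum(dim=-1)            # <a_k, -p_k (+)_c x>
    z2 = z.pow(2).sum(dim=-1)                                 # ||-p_k (+)_c x||^2
    arg = 2.0 * sqrt_c * inner / ((1.0 - c * z2) * a_norm)
    return lam * a_norm / sqrt_c * torch.asinh(arg)


# Small inputs so that every point lies well inside the unit ball (c = 1).
x = torch.randn(32, 10) * 0.1
p = torch.randn(5, 10) * 0.1
a = torch.randn(5, 10)
logits = mlr_logits_sketch(x, a, p, c=1.0)  # shape (32, 5)
```

One property that follows from the formula: when \(x = p_k\), the Möbius sum \(-p_k \oplus_c x\) is the zero vector, so the logit for class k is zero. The class representative therefore lies on its own decision hyperplane.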
Examples
>>> manifold = PoincareBall(curvature=1.0)
>>> batch_size, dim, n_classes = 32, 10, 5
>>> x = torch.randn(batch_size, dim) * 0.3
>>> x = manifold.project(x)
>>> weights = torch.randn(n_classes, dim)
>>> points = torch.randn(n_classes, dim) * 0.3
>>> points = manifold.project(points)
>>> logits = compute_hyperbolic_mlr_logits(x, weights, points, manifold)
>>> probs = torch.softmax(logits, dim=1)  # Classification probabilities
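The example calls `manifold.project` so that randomly drawn points land inside the ball before logits are computed. The library's own projection is not shown here. The sketch below illustrates one common approach, under the assumption that projection means rescaling any point whose norm reaches the ball radius \(1/\sqrt{c}\). The function name `project_to_ball_sketch` and the `eps` margin are hypothetical.

```python
import torch


def project_to_ball_sketch(x: torch.Tensor, c: float = 1.0, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: pull rows back inside the ball of radius 1/sqrt(c).

    Rows whose norm exceeds (1 - eps) / sqrt(c) are rescaled onto that radius;
    all other rows are returned unchanged.
    """
    max_norm = (1.0 - eps) / (c ** 0.5)
    norm = x.norm(dim=-1, keepdim=True).clamp_min(1e-15)
    scaled = x / norm * max_norm
    return torch.where(norm > max_norm, scaled, x)


raw = torch.randn(32, 10)           # many rows have norm > 1
inside = project_to_ball_sketch(raw, c=1.0)
```

Keeping points strictly inside the ball matters because the term \(1 - c\|\cdot\|^2\) in the logit formula appears in a denominator and reaches zero on the boundary.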