NN

Functional

Functional interface for hyperbolic neural network operations.

This module provides stateless functions for hyperbolic computations, following PyTorch’s functional API pattern (similar to torch.nn.functional).

Functions

compute_hyperbolic_mlr_logits

Compute logits for hyperbolic multinomial logistic regression.

compute_hyperbolic_mlr_logits(x, weights, class_points, manifold)[source]

Compute logits for hyperbolic multinomial logistic regression (MLR).

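This function implements the hyperbolic generalization of softmax/MLR on the Poincaré ball. Each class is described by a learned class representative and a weight vector, which together define a geodesic hyperplane; the logit of an input point for that class is a scaled signed hyperbolic distance from the point to that hyperplane.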

Parameters:
  • x (torch.Tensor) – Input points on the Poincaré ball. Shape (batch_size, dim).

  • weights (torch.Tensor) – Weight vectors (a-values), one per class, scaled by the conformal factor. Shape (n_classes, dim).

  • class_points (torch.Tensor) – Class representatives (p-values) on the Poincaré ball. Shape (n_classes, dim).

  • manifold (MobiusManifold) – The hyperbolic manifold instance. Currently only PoincareBall is supported.

Returns:

Logits for each input point and class. Shape (batch_size, n_classes). Can be passed to standard softmax for classification probabilities.

Return type:

torch.Tensor

Raises:

NotImplementedError – If manifold is not an instance of PoincareBall.

Notes

The hyperbolic MLR generalizes logistic regression to hyperbolic space. For each class k with representative \(p_k\) and weights \(a_k\), the logit for an input point \(x\) is:

\[\text{logit}_k(x) = \frac{\lambda_{p_k}^c \|a_k\|}{\sqrt{c}} \sinh^{-1}\left(\frac{2\sqrt{c} \langle a_k, -p_k \oplus_c x \rangle} {(1 - c\|-p_k \oplus_c x\|^2)\|a_k\|}\right)\]

where:
  • \(\lambda_{p_k}^c = \frac{2}{1 - c\|p_k\|^2}\) is the conformal factor at \(p_k\)

  • \(\oplus_c\) denotes Möbius addition

  • \(\sinh^{-1}\) is the inverse hyperbolic sine (arcsinh)
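The logit formula relies on Möbius addition without spelling it out. For reference, the standard definition on the Poincaré ball with curvature parameter \(c\) is reproduced here; it is the textbook form, not taken from this module's source:

\[u \oplus_c v = \frac{(1 + 2c\langle u, v \rangle + c\|v\|^2)\,u + (1 - c\|u\|^2)\,v} {1 + 2c\langle u, v \rangle + c^2\|u\|^2\|v\|^2}\]

With \(u = -p_k\) and \(v = x\), the term \(-p_k \oplus_c x\) moves \(p_k\) to the origin. In particular, \(-p_k \oplus_c p_k = 0\), so the logit of class k is zero at its own representative.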

The formulation ensures that decision boundaries are geodesic hyperplanes in hyperbolic space.
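To make the formula concrete, the following is a minimal pure-Python sketch that evaluates the logit for a single point and a single class. It is an illustration written directly from the equation above, not the library's implementation; the helper names `dot`, `mobius_add`, and `mlr_logit` are hypothetical and do not exist in this module.

```python
import math


def dot(u, v):
    """Euclidean inner product of two equal-length sequences."""
    return sum(ui * vi for ui, vi in zip(u, v))


def mobius_add(u, v, c):
    """Mobius addition u (+)_c v on the Poincare ball (standard definition)."""
    uv, u2, v2 = dot(u, v), dot(u, u), dot(v, v)
    coef_u = 1.0 + 2.0 * c * uv + c * v2
    coef_v = 1.0 - c * u2
    denom = 1.0 + 2.0 * c * uv + c * c * u2 * v2
    return [(coef_u * ui + coef_v * vi) / denom for ui, vi in zip(u, v)]


def mlr_logit(x, a, p, c=1.0):
    """Logit of point x for one class with weight vector a and representative p.

    Hypothetical scalar reference for the formula in the Notes section.
    """
    z = mobius_add([-pi for pi in p], x, c)      # -p (+)_c x
    a_norm = math.sqrt(dot(a, a))
    lam = 2.0 / (1.0 - c * dot(p, p))            # conformal factor at p
    arg = 2.0 * math.sqrt(c) * dot(a, z) / ((1.0 - c * dot(z, z)) * a_norm)
    return lam * a_norm / math.sqrt(c) * math.asinh(arg)


p = [0.1, 0.2]   # class representative inside the unit ball
a = [1.0, 0.0]   # weight (normal) vector

on_boundary = mlr_logit(p, a, p)             # x == p lies on the hyperplane
positive_side = mlr_logit([0.3, 0.2], a, p)  # displaced along +a
negative_side = mlr_logit([-0.1, 0.2], a, p) # displaced along -a
```

The sign of the logit indicates which side of the class hyperplane the point lies on, and its magnitude grows with the distance from it. As \(c \to 0\) the expression approaches \(4\langle a_k, x - p_k \rangle\), the Euclidean linear logit up to a constant factor.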

Examples

>>> import torch  # PoincareBall and compute_hyperbolic_mlr_logits come from this library
>>> manifold = PoincareBall(curvature=1.0)
>>> batch_size, dim, n_classes = 32, 10, 5
>>> x = torch.randn(batch_size, dim) * 0.3
>>> x = manifold.project(x)  # Ensure inputs lie on the Poincaré ball
>>> weights = torch.randn(n_classes, dim)
>>> points = torch.randn(n_classes, dim) * 0.3
>>> points = manifold.project(points)  # Class representatives must also lie on the ball
>>> logits = compute_hyperbolic_mlr_logits(x, weights, points, manifold)
>>> probs = torch.softmax(logits, dim=1)  # Classification probabilities, shape (32, 5)

Layers

Modules