Tensor
Autograd
Automatic differentiation for hyperbolic geometry.
This module provides custom autograd functions for Riemannian gradient computation in hyperbolic space, ensuring correct gradient flow during backpropagation on manifolds.
Classes
- RiemannianGradient
Custom autograd function for Riemannian gradient scaling.
Functions
- apply_riemannian_gradient
Apply Riemannian gradient correction to tensors on the Poincaré ball.
- class RiemannianGradient(*args, **kwargs)
Bases: torch.autograd.Function
Custom autograd function for Riemannian gradient computation in hyperbolic space.
This class implements a custom backward pass that scales Euclidean gradients to Riemannian gradients appropriate for optimization on the Poincaré ball model. It’s essential for correct gradient-based optimization in hyperbolic neural networks.
Notes
In hyperbolic geometry the metric tensor differs from the Euclidean one, so Euclidean gradients must be rescaled to account for the curvature of the space. The Poincaré ball has a conformal metric: at every point the metric tensor is a scalar multiple of the identity matrix.
The scaling factor at point x is:
\[\text{scale} = \frac{(1 - c\|x\|^2)^2}{4}\]
This ensures gradients respect the hyperbolic geometry during backpropagation.
This is implemented as a custom autograd function to efficiently handle the gradient transformation during the backward pass without affecting the forward computation.
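The scaling factor follows directly from the conformal form of the metric. Writing the Poincaré-ball metric with the standard conformal factor \(\lambda_x\) (notation introduced here only for this derivation), the Riemannian gradient is the inverse metric applied to the Euclidean gradient:

\[g_x = \lambda_x^2 \, I, \qquad \lambda_x = \frac{2}{1 - c\|x\|^2}\]
\[\nabla_R f(x) = g_x^{-1} \nabla_E f(x) = \frac{1}{\lambda_x^2} \nabla_E f(x) = \frac{(1 - c\|x\|^2)^2}{4} \nabla_E f(x)\]

Because \(g_x^{-1}\) is a scalar times the identity, the correction rescales the gradient without changing its direction.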
- static forward(ctx, x, curvature)
Forward pass: returns the input unchanged but saves context for the backward pass.
- Parameters:
ctx (FunctionCtx) – Context object for storing information needed in backward pass.
x (torch.Tensor) – Point on the Poincaré ball.
curvature (torch.Tensor) – Curvature parameter c of the hyperbolic space (positive scalar tensor; the sectional curvature of the ball is −c).
- Returns:
The input x unchanged.
- Return type:
torch.Tensor
Notes
The forward pass is an identity operation. The gradient scaling only affects the backward pass, allowing this function to be transparently inserted into computational graphs.
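The identity-forward, scaled-backward pattern can be sketched with `torch.autograd.Function`. This is a minimal illustration, not this module's implementation: it assumes `x` has shape `(..., dim)` and `curvature` is a scalar tensor, and the names `_RiemannianGradSketch` and `apply_riemannian_gradient_sketch` are hypothetical.

```python
import torch


class _RiemannianGradSketch(torch.autograd.Function):
    """Hypothetical sketch: identity forward, conformal rescaling in backward."""

    @staticmethod
    def forward(ctx, x, curvature):
        # Keep x and c for the backward pass; the output is x itself (as a view).
        ctx.save_for_backward(x, curvature)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, curvature = ctx.saved_tensors
        # ||x||^2 over the last dimension, kept for broadcasting: shape (..., 1)
        sq_norm = x.pow(2).sum(dim=-1, keepdim=True)
        # Inverse squared conformal factor: (1 - c||x||^2)^2 / 4
        scale = (1 - curvature * sq_norm).pow(2) / 4
        # One gradient per forward input; None means no gradient for curvature.
        return grad_output * scale, None


def apply_riemannian_gradient_sketch(x, curvature):
    """Hypothetical wrapper mirroring apply_riemannian_gradient."""
    return _RiemannianGradSketch.apply(x, curvature)


x = torch.tensor([[0.5, 0.0], [0.0, 0.0]], requires_grad=True)
c = torch.tensor(1.0)
y = apply_riemannian_gradient_sketch(x, c)  # forward: y equals x
y.sum().backward()
# The Euclidean gradient of sum() is all ones, so x.grad holds the scale factors:
# row 0: (1 - 0.25)^2 / 4 = 0.140625; row 1 (the origin): 1 / 4 = 0.25
```

Returning `x.view_as(x)` rather than a copy keeps the forward pass free of extra memory while still giving autograd a distinct output node to attach the custom backward to.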
- static backward(ctx, grad_output)
Backward pass: scales the Euclidean gradient to the Riemannian gradient.
- Parameters:
ctx (FunctionCtx) – Context with saved tensors from forward pass.
grad_output (torch.Tensor) – Euclidean gradient flowing backward.
- Returns:
A tuple (grad_x, None): grad_x is the Riemannian gradient with respect to x, and None is returned for curvature because no gradient is propagated to it.
- Return type:
tuple[torch.Tensor, None]
Notes
The Riemannian gradient is computed as:
\[\nabla_R f(x) = \frac{(1 - c\|x\|^2)^2}{4} \nabla_E f(x)\]
where \(\nabla_R\) is the Riemannian gradient and \(\nabla_E\) is the Euclidean gradient.
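This scaling converts the Euclidean gradient into the Riemannian gradient, so update directions and step sizes account for the hyperbolic metric. Scaling alone does not make an update follow a geodesic; that additionally requires the optimizer to move x with the exponential map (or an approximating retraction).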
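As a worked example of the formula, take c = 1, the point x = (0.6, 0) and an arbitrary Euclidean gradient (1, −2). The helper `riemannian_scale` below is hypothetical and written in plain Python purely to show the arithmetic.

```python
def riemannian_scale(x, c):
    """Hypothetical helper: (1 - c * ||x||^2)^2 / 4 for a point given as a list of floats."""
    sq_norm = sum(v * v for v in x)
    return (1.0 - c * sq_norm) ** 2 / 4.0


x = [0.6, 0.0]             # a point inside the unit ball (c = 1)
euclid_grad = [1.0, -2.0]  # an arbitrary Euclidean gradient at x
s = riemannian_scale(x, c=1.0)            # (1 - 0.36)^2 / 4 = 0.4096 / 4 = 0.1024
riem_grad = [s * g for g in euclid_grad]  # approximately [0.1024, -0.2048]
```

The Riemannian gradient points in the same direction as the Euclidean one; only its length shrinks, here by a factor of roughly ten.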
- apply_riemannian_gradient(x, curvature)
Apply Riemannian gradient transformation for hyperbolic optimization.
This function wraps the RiemannianGradient autograd function, providing a convenient interface for applying gradient scaling in hyperbolic neural networks. It ensures that gradient-based optimization respects the geometry of the Poincaré ball.
- Parameters:
x (torch.Tensor) – Point on the Poincaré ball where gradient scaling should be applied.
curvature (torch.Tensor) – Positive curvature parameter of the hyperbolic space. Scalar tensor.
- Returns:
The input x with Riemannian gradient computation attached. Forward pass returns x unchanged, but backward pass will scale gradients appropriately.
- Return type:
torch.Tensor
Notes
This function should be applied to points on the Poincaré ball when they are created or after projecting to the manifold. It’s particularly important for:
- Points created by mapping from Euclidean space (e.g., in ToPoincare)
- Learned parameters that live on the manifold
- Points obtained after manifold projections, which may affect gradient flow
The gradient scaling factor \(\frac{(1 - c\|x\|^2)^2}{4}\) approaches zero as x approaches the boundary of the Poincaré ball (\(\|x\| \to 1/\sqrt{c}\)), reflecting the fact that the boundary lies at infinite hyperbolic distance.
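The boundary behaviour can be checked numerically. This plain-Python sketch (the helper name `scale_from_sq_norm` is hypothetical) evaluates the factor at increasing fractions of the ball's radius \(1/\sqrt{c}\) for c = 2.

```python
def scale_from_sq_norm(sq_norm, c):
    # (1 - c * ||x||^2)^2 / 4 expressed through the squared norm ||x||^2
    return (1.0 - c * sq_norm) ** 2 / 4.0


c = 2.0
radius = 1.0 / c ** 0.5  # the ball with curvature parameter c has radius 1/sqrt(c)
fractions = [0.0, 0.5, 0.9, 0.99, 0.999]
scales = [scale_from_sq_norm((f * radius) ** 2, c) for f in fractions]

for f, s in zip(fractions, scales):
    # c * (f * radius)^2 = f^2, so the scale is (1 - f^2)^2 / 4 for any c
    print(f"||x|| = {f} * radius -> scale = {s:.3e}")
```

The factor falls from 1/4 at the origin to below 1e-6 at 99.9% of the radius, so gradients for points near the boundary become very small.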
Operations
Core tensor operations for hyperbolic geometry.
This module implements low-level tensor operations required for numerical stability in hyperbolic computations, including safe norms and hyperbolic trigonometric functions.
Functions
- norm
Compute vector norm with numerical stability guarantees.
- squared_norm
Compute squared norm efficiently.
- dot_product
Compute dot product along the last dimension.
- tanh, atanh
Numerically stable hyperbolic trigonometric functions.
- norm(tensor, *, safe=False)
Compute the L2 norm of tensors along the last dimension.
This function computes the Euclidean norm (L2 norm) of input tensors, with an optional safety mechanism to prevent division by zero in subsequent operations.
- Parameters:
tensor (torch.Tensor) – Input tensor of any shape (…, dim).
safe (bool, optional) – If True, clamps the norm to be at least MIN_NORM_THRESHOLD so that subsequent divisions by the norm stay finite. Default is False. Keyword-only.
- Returns:
L2 norm along the last dimension. Shape (…, 1). The last dimension is kept for broadcasting compatibility.
- Return type:
torch.Tensor
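A minimal sketch of the documented behaviour, assuming a threshold value: the real MIN_NORM_THRESHOLD constant is not given here, so the `1e-15` below and the name `norm_sketch` are hypothetical.

```python
import torch

# Hypothetical threshold; the library's actual MIN_NORM_THRESHOLD may differ.
MIN_NORM_THRESHOLD = 1e-15


def norm_sketch(tensor: torch.Tensor, *, safe: bool = False) -> torch.Tensor:
    # L2 norm over the last dimension, kept so the result broadcasts: (..., 1)
    out = torch.linalg.vector_norm(tensor, ord=2, dim=-1, keepdim=True)
    if safe:
        # Lower-bound the norm so later divisions such as x / ||x|| stay finite.
        out = out.clamp_min(MIN_NORM_THRESHOLD)
    return out


x = torch.tensor([[3.0, 4.0], [0.0, 0.0]])
n = norm_sketch(x)              # shape (2, 1): 5.0 and 0.0
ns = norm_sketch(x, safe=True)  # the zero row is raised to the threshold
unit = x / ns                   # finite, whereas x / n would give 0/0 = nan
```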
- squared_norm(tensor)
Compute the squared L2 norm of tensors along the last dimension.
This function computes the squared Euclidean norm, which is more efficient than computing the norm when the square root is not needed.
- Parameters:
tensor (torch.Tensor) – Input tensor of any shape (…, dim).
- Returns:
Squared L2 norm along the last dimension. Shape (…, 1). The last dimension is kept for broadcasting compatibility.
- Return type:
torch.Tensor
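A minimal sketch of the squared norm and of why keeping the last dimension helps; the name `squared_norm_sketch` is hypothetical, and the factor `1 - 0.1 * sq` simply stands in for a term such as \(1 - c\|x\|^2\).

```python
import torch


def squared_norm_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Sum of squares over the last dimension; no square root is needed.
    return tensor.pow(2).sum(dim=-1, keepdim=True)


x = torch.tensor([[1.0, 2.0, 2.0], [0.5, 0.5, 0.0]])
sq = squared_norm_sketch(x)  # shape (2, 1): 9.0 and 0.5
# The kept dimension lets (2, 1) broadcast against x's (2, 3) without reshaping.
scaled = x * (1 - 0.1 * sq)
```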
- dot_product(x, y)
Compute the dot product between tensors along the last dimension.
This function computes the inner product (dot product) between corresponding vectors in two tensors, handling arbitrary batch dimensions.
- Parameters:
x (torch.Tensor) – First input tensor of shape (…, dim).
y (torch.Tensor) – Second input tensor. Must have the same shape as x.
- Returns:
Dot product along the last dimension. Shape (…, 1). The last dimension is kept for broadcasting compatibility.
- Return type:
torch.Tensor
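A minimal sketch of a batched dot product over the last dimension (the name `dot_product_sketch` is hypothetical). Note that this version would also broadcast compatible shapes, whereas the documented function requires identical shapes.

```python
import torch


def dot_product_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Elementwise product, then reduce over the last dimension; batch dims are kept.
    return (x * y).sum(dim=-1, keepdim=True)


x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = torch.tensor([[5.0, 6.0], [-1.0, 0.5]])
d = dot_product_sketch(x, y)  # shape (2, 1): 1*5 + 2*6 = 17.0 and 3*(-1) + 4*0.5 = -1.0
```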
- tanh(x)
Compute the hyperbolic tangent of a tensor element-wise.
This function is a wrapper around torch.tanh to ensure numerical stability.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Element-wise hyperbolic tangent of the input tensor.
- Return type:
torch.Tensor
- atanh(x)
Compute the inverse hyperbolic tangent of a tensor element-wise.
This function is a wrapper around torch.atanh to ensure numerical stability.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Element-wise inverse hyperbolic tangent of the input tensor.
- Return type:
torch.Tensor
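The documentation does not say how these wrappers achieve stability. One common approach in hyperbolic-network code, shown here purely as an assumption, is clamping into the open interval (−1, 1): `torch.atanh` diverges at ±1, and in float32 `torch.tanh` can round to exactly ±1 for moderately large inputs, which would place a point on the ball's boundary. The margin `EPS` and the names `stable_tanh` and `stable_atanh` are hypothetical.

```python
import torch

# Hypothetical margin; the library's own constants (if any) may differ.
EPS = 1e-5


def stable_tanh(x: torch.Tensor) -> torch.Tensor:
    # Keep the output strictly inside (-1, 1) even when tanh saturates.
    return torch.tanh(x).clamp(-1 + EPS, 1 - EPS)


def stable_atanh(x: torch.Tensor) -> torch.Tensor:
    # atanh(+-1) is infinite, so clamp the argument into (-1, 1) first.
    return torch.atanh(x.clamp(-1 + EPS, 1 - EPS))


edge = torch.tensor([-1.0, 0.0, 1.0])
raw = torch.atanh(edge)      # -inf, 0., inf
safe = stable_atanh(edge)    # all finite
big = stable_tanh(torch.tensor([50.0]))  # strictly below 1
```

The trade-off is a small bias near the boundary in exchange for never producing infinities that would poison subsequent gradients.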