"""
Core tensor operations for hyperbolic geometry.

This module implements low-level tensor operations required for numerical
stability in hyperbolic computations, including safe norms and hyperbolic
trigonometric functions.

Functions
---------
norm
    Compute the L2 norm, optionally clamped for numerical stability.
squared_norm
    Compute the squared L2 norm without taking a square root.
dot_product
    Compute the dot product along the last dimension.
tanh, atanh
    Hyperbolic tangent and its inverse, with clamped inputs.
"""
import torch
from hyptorch._config import NumericalConstants
[docs]
def norm(tensor: torch.Tensor, *, safe: bool = False) -> torch.Tensor:
    """
    Return the Euclidean (L2) length of each vector in ``tensor``.

    Vectors are read along the last dimension. That dimension is retained
    with size 1 so the result broadcasts against the input.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor of shape (..., dim).
    safe : bool, optional
        Keyword-only. When True, the result is clamped from below at
        ``NumericalConstants.MIN_NORM_THRESHOLD``, which is intended to
        protect later divisions by the norm. Default is False.

    Returns
    -------
    torch.Tensor
        L2 norms along the last dimension, shape (..., 1).
    """
    length = torch.linalg.norm(tensor, dim=-1, keepdim=True)
    if not safe:
        return length
    # Safe mode: never report a norm smaller than the configured floor.
    return torch.clamp_min(length, NumericalConstants.MIN_NORM_THRESHOLD)
[docs]
def squared_norm(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return the squared Euclidean (L2) norm of each vector in ``tensor``.

    The elements are squared and summed along the last dimension; no
    square root is taken, so this is cheaper than ``norm`` when only the
    squared length is needed.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor of shape (..., dim).

    Returns
    -------
    torch.Tensor
        Squared L2 norms along the last dimension, shape (..., 1).
        The last dimension is kept for broadcasting compatibility.
    """
    squares = tensor.pow(2)
    return squares.sum(dim=-1, keepdim=True)
[docs]
def dot_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Return the inner product of paired vectors in ``x`` and ``y``.

    The tensors are multiplied element-wise and the products are summed
    along the last dimension; any leading dimensions act as batch
    dimensions.

    Parameters
    ----------
    x, y : torch.Tensor
        Input tensors, expected to have the same shape (..., dim). The
        shapes are not validated here; the element-wise product follows
        PyTorch broadcasting rules.

    Returns
    -------
    torch.Tensor
        Dot products along the last dimension, shape (..., 1).
        The last dimension is kept for broadcasting compatibility.
    """
    elementwise = x * y
    return elementwise.sum(dim=-1, keepdim=True)
[docs]
def tanh(x: torch.Tensor) -> torch.Tensor:
    """
    Apply the hyperbolic tangent element-wise to a clamped input.

    The input is first restricted to the interval
    [``NumericalConstants.TANH_CLAMP_MIN``,
    ``NumericalConstants.TANH_CLAMP_MAX``] and then passed to
    ``torch.tanh``. The clamp is the numerical-stability safeguard this
    wrapper adds.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Element-wise hyperbolic tangent of the clamped input, with the
        same shape as ``x``.
    """
    lower = NumericalConstants.TANH_CLAMP_MIN
    upper = NumericalConstants.TANH_CLAMP_MAX
    bounded = torch.clamp(x, min=lower, max=upper)
    return torch.tanh(bounded)
[docs]
def atanh(x: torch.Tensor) -> torch.Tensor:
    """
    Apply the inverse hyperbolic tangent element-wise to a clamped input.

    The input is first restricted to the interval
    [``NumericalConstants.ATANH_CLAMP_MIN``,
    ``NumericalConstants.ATANH_CLAMP_MAX``] and then passed to
    ``torch.atanh``. The clamp is the numerical-stability safeguard this
    wrapper adds.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    Returns
    -------
    torch.Tensor
        Element-wise inverse hyperbolic tangent of the clamped input,
        with the same shape as ``x``.
    """
    lower = NumericalConstants.ATANH_CLAMP_MIN
    upper = NumericalConstants.ATANH_CLAMP_MAX
    bounded = torch.clamp(x, min=lower, max=upper)
    return torch.atanh(bounded)