sonnix.utils

sonnix.utils.device

Contain utility functions to manage the device(s) of a torch.nn.Module.

sonnix.utils.device.get_module_device

get_module_device(module: Module) -> device

Get the device used by this module.

This function assumes the module uses a single device. If the module uses several devices, use get_module_devices instead. It returns torch.device('cpu') if the module does not have parameters.

Parameters:

module (Module, required): The module.

Returns:

device: The device.

Example
>>> import torch
>>> from sonnix.utils.device import get_module_device
>>> get_module_device(torch.nn.Linear(4, 6))
device(type='cpu')
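
A minimal sketch of the single-device lookup described above, assuming that the device of the first parameter is representative of the whole module and that a module without parameters falls back to the CPU, as documented. The helper name first_parameter_device is hypothetical and not part of sonnix.

```python
import torch
from torch import nn


def first_parameter_device(module: nn.Module) -> torch.device:
    """Hypothetical helper: return the device of the first parameter.

    Falls back to the CPU device when the module has no parameters.
    """
    for param in module.parameters():
        return param.device
    return torch.device("cpu")


print(first_parameter_device(nn.Linear(4, 6)))  # cpu
print(first_parameter_device(nn.Identity()))  # cpu (no parameters)
```

Because only the first parameter is inspected, this sketch cannot detect a module split across devices; that case is what a multi-device helper is for.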

sonnix.utils.device.get_module_devices

get_module_devices(module: Module) -> list[device]

Get the devices used in a module.

Parameters:

module (Module, required): The module.

Returns:

list[device]: The list of torch.device objects used in the module.

Example
>>> import torch
>>> from sonnix.utils.device import get_module_devices
>>> get_module_devices(torch.nn.Linear(4, 6))
[device(type='cpu')]
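
Collecting every device used by a module can be sketched by walking its parameters and keeping each distinct device once. This is a minimal sketch under two assumptions: only parameters are inspected (buffers are ignored), and devices are reported in order of first appearance. The helper name collect_parameter_devices is hypothetical.

```python
import torch
from torch import nn


def collect_parameter_devices(module: nn.Module) -> list[torch.device]:
    """Hypothetical helper: list the distinct devices of a module's parameters."""
    devices: list[torch.device] = []
    for param in module.parameters():
        # torch.device supports equality, so `in` detects duplicates.
        if param.device not in devices:
            devices.append(param.device)
    return devices


model = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 3))
print(collect_parameter_devices(model))  # [device(type='cpu')]
```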

sonnix.utils.device.is_module_on_device

is_module_on_device(module: Module, device: device) -> bool

Indicate if all the parameters of a module are on the specified device.

Parameters:

module (Module, required): The module.
device (device, required): The device.

Returns:

bool: True if all the parameters of the module are on the specified device, otherwise False.

Example
>>> import torch
>>> from sonnix.utils.device import is_module_on_device
>>> is_module_on_device(torch.nn.Linear(4, 6), torch.device("cpu"))
True
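
The check described above reduces to comparing each parameter's device with the requested one. This is a minimal sketch, assuming a plain equality comparison; the helper name all_parameters_on is hypothetical, and how sonnix treats a module without parameters is not stated above (here all() makes it vacuously True).

```python
import torch
from torch import nn


def all_parameters_on(module: nn.Module, device: torch.device) -> bool:
    """Hypothetical helper: check that every parameter lives on `device`."""
    return all(param.device == device for param in module.parameters())


print(all_parameters_on(nn.Linear(4, 6), torch.device("cpu")))  # True
print(all_parameters_on(nn.Linear(4, 6), torch.device("meta")))  # False
```

With equality-based comparison, the device index matters: torch.device("cuda") and torch.device("cuda:0") compare unequal, so pass the fully qualified device when checking GPU placement.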

sonnix.utils.imports

Implement some utility functions to manage optional dependencies.

sonnix.utils.imports.check_objectory

check_objectory() -> None

Check if the objectory package is installed.

Raises:

RuntimeError: If the objectory package is not installed.

Example
>>> from sonnix.utils.imports import check_objectory
>>> check_objectory()

sonnix.utils.imports.is_objectory_available

is_objectory_available() -> bool

Indicate if the objectory package is installed or not.

Returns:

bool: True if objectory is available, otherwise False.

Example
>>> from sonnix.utils.imports import is_objectory_available
>>> is_objectory_available()

sonnix.utils.imports.objectory_available

objectory_available(
    fn: Callable[..., Any],
) -> Callable[..., Any]

Implement a decorator to execute a function only if the objectory package is installed.

Parameters:

fn (Callable[..., Any], required): The function to execute.

Returns:

Callable[..., Any]: A wrapper around fn if the objectory package is installed, otherwise None.

Example
>>> from sonnix.utils.imports import objectory_available
>>> @objectory_available
... def my_function(n: int = 0) -> int:
...     return 42 + n
...
>>> my_function()
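
A decorator that guards a function behind an optional dependency can be sketched as follows. This is a hypothetical generalization: unlike objectory_available, which is applied directly to the function, the package_available factory takes the package name as an argument. The sketch also assumes the check happens at call time and that a call returns None when the package is missing.

```python
import functools
import importlib.util
from typing import Any, Callable


def package_available(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator factory: run the function only if `name` is installed."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)  # keep the original name and docstring
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if importlib.util.find_spec(name) is None:
                return None  # package missing: skip the call
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@package_available("json")
def answer(n: int = 0) -> int:
    return 42 + n


@package_available("surely_not_an_installed_package")
def skipped() -> str:
    return "executed"


print(answer(), skipped())  # 42 None
```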

sonnix.utils.imports.raise_error_objectory_missing

raise_error_objectory_missing() -> NoReturn

Raise a RuntimeError to indicate the objectory package is missing.

sonnix.utils.iterator

Contain iterators on torch.nn.Module.

sonnix.utils.iterator.get_named_modules

get_named_modules(
    module: Module, depth: int = 0
) -> Generator[tuple[str, Module]]

Return an iterator over the modules, yielding both the name of each module and the module itself.

Parameters:

module (Module, required): The input module.
depth (int, default 0): The maximum depth of the modules to yield. With depth=0, only the root module is yielded.

Returns:

Generator[tuple[str, Module]]: The iterator over the modules and their names.

Example
>>> import torch
>>> from sonnix.utils.iterator import get_named_modules
>>> module = torch.nn.Linear(4, 6)
>>> named_modules = list(get_named_modules(module))
>>> named_modules
[('[root]', Linear(in_features=4, out_features=6, bias=True))]
>>> module = torch.nn.Sequential(
...     torch.nn.Linear(4, 6), torch.nn.ReLU(), torch.nn.Linear(6, 3)
... )
>>> named_modules = list(get_named_modules(module))
>>> named_modules
[('[root]', Sequential(
  (0): Linear(in_features=4, out_features=6, bias=True)
  (1): ReLU()
  (2): Linear(in_features=6, out_features=3, bias=True)
))]
>>> named_modules = list(get_named_modules(module, depth=1))
>>> named_modules
[('[root]', Sequential(
  (0): Linear(in_features=4, out_features=6, bias=True)
  (1): ReLU()
  (2): Linear(in_features=6, out_features=3, bias=True)
)),
('0', Linear(in_features=4, out_features=6, bias=True)),
('1', ReLU()), ('2', Linear(in_features=6, out_features=3, bias=True))]
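
A depth-limited traversal matching the two examples above can be sketched with named_children and recursion. This is a minimal sketch: the "[root]" label and the depth 0 and 1 behavior follow the examples, while dot-joined names for depth > 1 (as in torch.nn.Module.named_modules) are an assumption. The helper names are hypothetical.

```python
from collections.abc import Iterator

from torch import nn


def named_modules_up_to(module: nn.Module, depth: int = 0) -> Iterator[tuple[str, nn.Module]]:
    """Hypothetical helper: yield (name, module) pairs down to `depth` levels."""
    yield "[root]", module
    yield from _named_children(module, depth, prefix="")


def _named_children(module: nn.Module, depth: int, prefix: str) -> Iterator[tuple[str, nn.Module]]:
    if depth <= 0:
        return
    for name, child in module.named_children():
        full_name = f"{prefix}{name}"
        yield full_name, child
        # Descend one level less, prefixing children with the parent's name.
        yield from _named_children(child, depth - 1, prefix=f"{full_name}.")


model = nn.Sequential(nn.Linear(4, 6), nn.Sequential(nn.ReLU(), nn.Linear(6, 3)))
print([name for name, _ in named_modules_up_to(model, depth=1)])  # ['[root]', '0', '1']
print([name for name, _ in named_modules_up_to(model, depth=2)])
# ['[root]', '0', '1', '1.0', '1.1']
```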

sonnix.utils.loss

Contain utility functions to check if the loss is decreasing.

sonnix.utils.loss.is_loss_decreasing

is_loss_decreasing(
    module: Module,
    criterion: Module | Callable[[Tensor, Tensor], Tensor],
    optimizer: Optimizer,
    feature: Tensor,
    target: Tensor,
    num_iterations: int = 1,
    random_seed: int = 10772155803920552556,
) -> bool

Check whether the loss decreases after a given number of optimization steps.

Parameters:

module (Module, required): The module to test. The module must have a single input tensor and a single output tensor.
criterion (Module | Callable[[Tensor, Tensor], Tensor], required): The criterion used to compute the loss.
optimizer (Optimizer, required): The optimizer used to update the weights of the module.
feature (Tensor, required): The input of the module.
target (Tensor, required): The target used to compute the loss.
num_iterations (int, default 1): The number of optimization steps.
random_seed (int, default 10772155803920552556): The random seed used to make the function deterministic if the module contains randomness.

Returns:

bool: True if the loss decreased after num_iterations optimization steps, otherwise False.

Example
>>> import torch
>>> from torch import nn
>>> from torch.optim import SGD
>>> from sonnix.utils.loss import is_loss_decreasing
>>> module = nn.Linear(4, 2)
>>> is_loss_decreasing(
...     module=module,
...     criterion=nn.MSELoss(),
...     optimizer=SGD(module.parameters(), lr=0.01),
...     feature=torch.rand(4, 4),
...     target=torch.rand(4, 2),
... )
True
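
The check can be sketched as: measure the loss, run a few optimizer steps, then measure it again. This is a minimal sketch, not sonnix's implementation; the helper name loss_decreases is hypothetical, and reseeding before each evaluation (so random layers such as dropout draw the same masks for the before/after comparison) is an assumed design choice.

```python
from collections.abc import Callable

import torch
from torch import nn


def loss_decreases(
    module: nn.Module,
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    optimizer: torch.optim.Optimizer,
    feature: torch.Tensor,
    target: torch.Tensor,
    num_iterations: int = 1,
    random_seed: int = 0,
) -> bool:
    """Hypothetical helper: compare the loss before and after a few optimizer steps."""
    torch.manual_seed(random_seed)  # same RNG state for both evaluations
    with torch.no_grad():
        initial_loss = criterion(module(feature), target).item()

    for _ in range(num_iterations):
        optimizer.zero_grad()
        loss = criterion(module(feature), target)
        loss.backward()
        optimizer.step()

    torch.manual_seed(random_seed)
    with torch.no_grad():
        final_loss = criterion(module(feature), target).item()
    return final_loss < initial_loss


torch.manual_seed(0)
module = nn.Linear(4, 2)
feature, target = torch.rand(4, 4), torch.rand(4, 2)
optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
print(loss_decreases(module, nn.MSELoss(), optimizer, feature, target))  # True
```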

sonnix.utils.loss.is_loss_decreasing_with_adam

is_loss_decreasing_with_adam(
    module: Module,
    criterion: Module | Callable[[Tensor, Tensor], Tensor],
    feature: Tensor,
    target: Tensor,
    lr: float = 0.0003,
    num_iterations: int = 1,
    random_seed: int = 10772155803920552556,
) -> bool

Check whether the loss decreases after a given number of optimization steps.

The module is trained with the torch.optim.Adam optimizer.

Parameters:

module (Module, required): The module to test. The module must have a single input tensor and a single output tensor.
criterion (Module | Callable[[Tensor, Tensor], Tensor], required): The criterion used to compute the loss.
feature (Tensor, required): The input of the module.
target (Tensor, required): The target used to compute the loss.
lr (float, default 0.0003): The learning rate.
num_iterations (int, default 1): The number of optimization steps.
random_seed (int, default 10772155803920552556): The random seed used to make the function deterministic if the module contains randomness.

Returns:

bool: True if the loss decreased after num_iterations optimization steps, otherwise False.

Example
>>> import torch
>>> from torch import nn
>>> from sonnix.utils.loss import is_loss_decreasing_with_adam
>>> is_loss_decreasing_with_adam(
...     module=nn.Linear(4, 2),
...     criterion=nn.MSELoss(),
...     feature=torch.rand(4, 4),
...     target=torch.rand(4, 2),
...     lr=0.0003,
... )
True

sonnix.utils.loss.is_loss_decreasing_with_sgd

is_loss_decreasing_with_sgd(
    module: Module,
    criterion: Module | Callable[[Tensor, Tensor], Tensor],
    feature: Tensor,
    target: Tensor,
    lr: float = 0.01,
    num_iterations: int = 1,
    random_seed: int = 10772155803920552556,
) -> bool

Check whether the loss decreases after a given number of optimization steps.

The module is trained with the torch.optim.SGD optimizer.

Parameters:

module (Module, required): The module to test. The module must have a single input tensor and a single output tensor.
criterion (Module | Callable[[Tensor, Tensor], Tensor], required): The criterion used to compute the loss.
feature (Tensor, required): The input of the module.
target (Tensor, required): The target used to compute the loss.
lr (float, default 0.01): The learning rate.
num_iterations (int, default 1): The number of optimization steps.
random_seed (int, default 10772155803920552556): The random seed used to make the function deterministic if the module contains randomness.

Returns:

bool: True if the loss decreased after num_iterations optimization steps, otherwise False.

Example
>>> import torch
>>> from torch import nn
>>> from sonnix.utils.loss import is_loss_decreasing_with_sgd
>>> is_loss_decreasing_with_sgd(
...     module=nn.Linear(4, 2),
...     criterion=nn.MSELoss(),
...     feature=torch.rand(4, 4),
...     target=torch.rand(4, 2),
...     lr=0.01,
... )
True
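
The two optimizer-specific variants differ only in which optimizer they build from module.parameters(). This is a minimal sketch of that pattern; loss_decreases_with and its two wrappers are hypothetical names, and random seeding is omitted for brevity. The default learning rates mirror the documented ones (0.0003 for Adam, 0.01 for SGD).

```python
from collections.abc import Callable

import torch
from torch import nn

Criterion = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def loss_decreases_with(
    optimizer_cls: type[torch.optim.Optimizer],
    module: nn.Module,
    criterion: Criterion,
    feature: torch.Tensor,
    target: torch.Tensor,
    lr: float,
    num_iterations: int = 1,
) -> bool:
    """Hypothetical helper: build the optimizer internally, then compare losses."""
    optimizer = optimizer_cls(module.parameters(), lr=lr)
    with torch.no_grad():
        initial_loss = criterion(module(feature), target).item()
    for _ in range(num_iterations):
        optimizer.zero_grad()
        criterion(module(feature), target).backward()
        optimizer.step()
    with torch.no_grad():
        final_loss = criterion(module(feature), target).item()
    return final_loss < initial_loss


def loss_decreases_with_adam(module, criterion, feature, target, lr=3e-4, num_iterations=1):
    return loss_decreases_with(torch.optim.Adam, module, criterion, feature, target, lr, num_iterations)


def loss_decreases_with_sgd(module, criterion, feature, target, lr=0.01, num_iterations=1):
    return loss_decreases_with(torch.optim.SGD, module, criterion, feature, target, lr, num_iterations)


torch.manual_seed(0)
feature, target = torch.rand(4, 4), torch.rand(4, 2)
print(loss_decreases_with_adam(nn.Linear(4, 2), nn.MSELoss(), feature, target))  # True
print(loss_decreases_with_sgd(nn.Linear(4, 2), nn.MSELoss(), feature, target))  # True
```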

sonnix.utils.mode

Contain utility functions to manage the mode of a torch.nn.Module.

sonnix.utils.mode.module_mode

module_mode(module: Module) -> Generator[None]

Implement a context manager that restores the mode (train or eval) of every submodule individually.

Parameters:

module (Module, required): The module whose mode is restored.

Example
>>> import torch
>>> from sonnix.utils.mode import module_mode
>>> module = torch.nn.ModuleDict(
...     {"module1": torch.nn.Linear(4, 6), "module2": torch.nn.Linear(2, 4).eval()}
... )
>>> print(module["module1"].training, module["module2"].training)
True False
>>> with module_mode(module):
...     module.eval()
...     print(module["module1"].training, module["module2"].training)
...
ModuleDict(
  (module1): Linear(in_features=4, out_features=6, bias=True)
  (module2): Linear(in_features=2, out_features=4, bias=True)
)
False False
>>> print(module["module1"].training, module["module2"].training)
True False
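
Restoring every submodule individually can be sketched by recording each training flag before the block and writing it back afterwards. This is a minimal sketch, not sonnix's implementation; the name preserve_modes is hypothetical. Assigning the flag directly matters, because calling .train() on a submodule would also overwrite its children.

```python
from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn


@contextmanager
def preserve_modes(module: nn.Module) -> Iterator[None]:
    """Hypothetical context manager: restore each submodule's own train/eval flag."""
    saved = {name: submodule.training for name, submodule in module.named_modules()}
    try:
        yield
    finally:
        for name, submodule in module.named_modules():
            if name in saved:
                # Set the flag directly: .train() would recurse into children.
                submodule.training = saved[name]


model = nn.ModuleDict({"a": nn.Linear(4, 6), "b": nn.Linear(2, 4).eval()})
with preserve_modes(model):
    model.eval()
print(model["a"].training, model["b"].training)  # True False
```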

sonnix.utils.mode.top_module_mode

top_module_mode(module: Module) -> Generator[None]

Implement a context manager that restores the mode (train or eval) of a given module.

This context manager only restores the mode of the top-level module.

Parameters:

module (Module, required): The module whose mode is restored.

Example
>>> import torch
>>> from sonnix.utils.mode import top_module_mode
>>> module = torch.nn.Linear(4, 6)
>>> print(module.training)
True
>>> with top_module_mode(module):
...     module.eval()
...     print(module.training)
...
Linear(in_features=4, out_features=6, bias=True)
False
>>> print(module.training)
True
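
A top-level-only variant can be sketched by saving module.training and restoring it with module.train(mode). This is a minimal sketch under that assumption; the name preserve_top_mode is hypothetical. Note that module.train(mode) propagates the saved mode to every submodule, so a submodule that was in a different mode before the block is not restored to its own mode, as the example shows.

```python
from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn


@contextmanager
def preserve_top_mode(module: nn.Module) -> Iterator[None]:
    """Hypothetical context manager: restore only the top-level train/eval flag."""
    mode = module.training
    try:
        yield
    finally:
        # train(mode) also applies `mode` to every submodule.
        module.train(mode)


model = nn.Sequential(nn.Linear(4, 6), nn.Dropout(0.5).eval())
with preserve_top_mode(model):
    model.eval()
print(model.training, model[1].training)  # True True
```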

sonnix.utils.params

Contain utility functions to analyze and manage torch.nn.Module parameters.

sonnix.utils.params.freeze_module

freeze_module(module: Module) -> None

Freeze the parameters of the given module.

Parameters:

module (Module, required): The module to freeze.

Example
>>> import torch
>>> from sonnix.utils.params import freeze_module
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> for name, param in module.named_parameters():
...     print(name, param.requires_grad)
...
weight False
bias False
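
Freezing amounts to setting requires_grad=False on every parameter, which PyTorch exposes directly as nn.Module.requires_grad_. This sketch assumes freeze_module has the same effect; sonnix may implement it differently.

```python
import torch
from torch import nn

module = nn.Linear(4, 6)
# Built-in equivalent of freezing: set requires_grad=False on every parameter.
module.requires_grad_(False)
print([param.requires_grad for param in module.parameters()])  # [False, False]

# A frozen module still runs forward, but its output no longer tracks gradients.
out = module(torch.ones(2, 4))
print(out.requires_grad)  # False
```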

sonnix.utils.params.has_learnable_parameters

has_learnable_parameters(module: Module) -> bool

Indicate if the module has learnable parameters.

Parameters:

module (Module, required): The module to test.

Returns:

bool: True if the module has at least one learnable parameter, False otherwise.

Example
>>> import torch
>>> from sonnix.utils.params import has_learnable_parameters, freeze_module
>>> has_learnable_parameters(torch.nn.Linear(4, 6))
True
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> has_learnable_parameters(module)
False
>>> has_learnable_parameters(torch.nn.Identity())
False

sonnix.utils.params.has_parameters

has_parameters(module: Module) -> bool

Indicate if the module has parameters.

Parameters:

module (Module, required): The module to test.

Returns:

bool: True if the module has at least one parameter, False otherwise.

Example
>>> import torch
>>> from sonnix.utils.params import has_parameters
>>> has_parameters(torch.nn.Linear(4, 6))
True
>>> has_parameters(torch.nn.Identity())
False
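
Both checks above can be sketched as an early-exit scan over module.parameters(), where "learnable" is assumed to mean requires_grad=True (consistent with the freeze_module example). The helper name has_any_parameter and its learnable_only flag are hypothetical.

```python
from torch import nn


def has_any_parameter(module: nn.Module, learnable_only: bool = False) -> bool:
    """Hypothetical helper: stop at the first (learnable) parameter found."""
    for param in module.parameters():
        if not learnable_only or param.requires_grad:
            return True
    return False


frozen = nn.Linear(4, 6).requires_grad_(False)  # requires_grad_ returns the module
print(has_any_parameter(frozen), has_any_parameter(frozen, learnable_only=True))  # True False
print(has_any_parameter(nn.Identity()))  # False
```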

sonnix.utils.params.num_learnable_parameters

num_learnable_parameters(module: Module) -> int

Return the number of learnable parameters in the module.

Parameters:

module (Module, required): The module whose learnable parameters are counted.

Returns:

int: The number of learnable parameters.

Example
>>> import torch
>>> from sonnix.utils.params import freeze_module, num_learnable_parameters
>>> num_learnable_parameters(torch.nn.Linear(4, 6))
30
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> num_learnable_parameters(module)
0
>>> num_learnable_parameters(torch.nn.Identity())
0

sonnix.utils.params.num_parameters

num_parameters(module: Module) -> int

Return the number of parameters in the module.

Parameters:

module (Module, required): The module whose parameters are counted.

Returns:

int: The number of parameters.

Example
>>> import torch
>>> from sonnix.utils.params import num_parameters
>>> num_parameters(torch.nn.Linear(4, 6))
30
>>> num_parameters(torch.nn.Identity())
0
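
The two counts above can be sketched as a sum of numel() over the parameters, optionally restricted to those with requires_grad=True. This is a minimal sketch; the helper name count_parameters is hypothetical. Since module.parameters() yields a shared (tied) parameter only once, tied weights are counted once.

```python
from torch import nn


def count_parameters(module: nn.Module, learnable_only: bool = False) -> int:
    """Hypothetical helper: total number of scalar values across parameters."""
    return sum(
        param.numel()
        for param in module.parameters()
        if param.requires_grad or not learnable_only
    )


linear = nn.Linear(4, 6)  # 4 * 6 weights + 6 biases = 30
print(count_parameters(linear), count_parameters(linear, learnable_only=True))  # 30 30
linear.requires_grad_(False)
print(count_parameters(linear), count_parameters(linear, learnable_only=True))  # 30 0
```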

sonnix.utils.params.unfreeze_module

unfreeze_module(module: Module) -> None

Unfreeze the parameters of the given module.

Parameters:

module (Module, required): The module to unfreeze.

Example
>>> import torch
>>> from sonnix.utils.params import freeze_module, unfreeze_module
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> unfreeze_module(module)
>>> for name, param in module.named_parameters():
...     print(name, param.requires_grad)
...
weight True
bias True

sonnix.utils.random

Contain utility functions to manage randomness.

sonnix.utils.random.get_random_seed

get_random_seed(seed: int) -> int

Get a random seed derived from the given seed.

Parameters:

seed (int, required): A random seed to make the process reproducible.

Returns:

int: A random seed. The value is between -2 ** 63 and 2 ** 63 - 1.

Example
>>> from sonnix.utils.random import get_random_seed
>>> get_random_seed(44)
6176747449835261347
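
Deriving a reproducible signed 64-bit seed from an input seed can be sketched with a seeded pseudo-random generator. This sketch swaps in Python's random.Random, so it does not reproduce the value shown above; the helper name derive_seed is hypothetical. The range matches the documented one (random.Random.randint includes both bounds).

```python
import random


def derive_seed(seed: int) -> int:
    """Hypothetical helper: map `seed` to a reproducible value in [-2**63, 2**63 - 1]."""
    return random.Random(seed).randint(-(2**63), 2**63 - 1)


print(derive_seed(44) == derive_seed(44))  # True: same input, same output
```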

sonnix.utils.random.get_torch_generator

get_torch_generator(
    random_seed: int = 1,
    device: device | str | None = "cpu",
) -> Generator

Create a torch.Generator initialized with a given seed.

Parameters:

random_seed (int, default 1): A random seed.
device (device | str | None, default 'cpu'): The desired device for the generator.

Returns:

Generator: A torch.Generator object.

Example
>>> import torch
>>> from sonnix.utils.random import get_torch_generator
>>> generator = get_torch_generator(42)
>>> torch.rand(2, 4, generator=generator)
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936]])
>>> generator = get_torch_generator(42)
>>> torch.rand(2, 4, generator=generator)
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936]])
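
Creating a seeded generator can be sketched with torch.Generator and its manual_seed method. This is a minimal sketch; the helper name make_generator is hypothetical, and unlike the documented signature it does not accept device=None, since how sonnix handles None is not stated here. A dedicated generator leaves the global RNG state (set by torch.manual_seed) untouched.

```python
from __future__ import annotations

import torch


def make_generator(random_seed: int = 1, device: torch.device | str = "cpu") -> torch.Generator:
    """Hypothetical helper: create a torch.Generator seeded with `random_seed`."""
    generator = torch.Generator(device=device)
    generator.manual_seed(random_seed)
    return generator


a = torch.rand(2, 4, generator=make_generator(42))
b = torch.rand(2, 4, generator=make_generator(42))
print(torch.equal(a, b))  # True: same seed, same draws
```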

sonnix.utils.random.setup_torch_generator

setup_torch_generator(
    generator_or_seed: int | Generator,
) -> Generator

Set up a torch.Generator object.

Parameters:

generator_or_seed (int | Generator, required): A torch.Generator object or a random seed.

Returns:

Generator: A torch.Generator object.

Example
>>> from sonnix.utils.random import setup_torch_generator
>>> generator = setup_torch_generator(42)
>>> generator
<torch._C.Generator object at 0x...>
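
Accepting either a generator or a seed can be sketched as a type dispatch. This is a minimal sketch under an assumption not stated above: an existing torch.Generator is returned unchanged, and an int seeds a new CPU generator. The helper name as_generator is hypothetical.

```python
from __future__ import annotations

import torch


def as_generator(generator_or_seed: int | torch.Generator) -> torch.Generator:
    """Hypothetical helper: pass generators through, seed a new one from an int."""
    if isinstance(generator_or_seed, torch.Generator):
        return generator_or_seed
    generator = torch.Generator()
    generator.manual_seed(generator_or_seed)
    return generator


g = torch.Generator()
print(as_generator(g) is g)  # True
print(as_generator(42).initial_seed())  # 42
```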

sonnix.utils.state_dict

Contain utility functions to manipulate torch.nn.Module's state dict.

sonnix.utils.state_dict.find_module_state_dict

find_module_state_dict(
    state_dict: (
        dict[Any, Any]
        | list[Any]
        | tuple[Any, ...]
        | set[Any]
    ),
    module_keys: set[Any],
) -> dict[Any, Any]

Try to automatically find the part of the state dict related to a module.

The user should specify the set of the module's keys: set(module.state_dict().keys()). This function assumes that this set of keys exists at only one location in the state dict. If it exists at several locations, only the first match is returned.

Parameters:

state_dict (dict[Any, Any] | list[Any] | tuple[Any, ...] | set[Any], required): The state dict. This function is called recursively on this input to find the queried state dict.
module_keys (set[Any], required): The set of module keys.

Returns:

dict[Any, Any]: The part of the state dict related to a module if it is found, otherwise an empty dict.

Example
>>> import torch
>>> from sonnix.utils.state_dict import find_module_state_dict
>>> state = {
...     "model": {
...         "weight": 42,
...         "network": {
...             "weight": torch.ones(5, 4),
...             "bias": 2 * torch.ones(5),
...         },
...     }
... }
>>> module = torch.nn.Linear(4, 5)
>>> state_dict = find_module_state_dict(state, module_keys=set(module.state_dict().keys()))
>>> state_dict
{'weight': tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]), 'bias': tensor([2., 2., 2., 2., 2.])}
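
The recursive search described above can be sketched as a depth-first walk that returns the first dict whose keys exactly equal module_keys. This is a minimal sketch; the helper name find_sub_state_dict is hypothetical, and exact key-set equality plus depth-first order are assumptions. Plain lists stand in for tensors so the example runs without PyTorch.

```python
from typing import Any


def find_sub_state_dict(state: Any, module_keys: set[Any]) -> dict[Any, Any]:
    """Hypothetical helper: depth-first search for a dict whose keys equal module_keys."""
    if isinstance(state, dict):
        if set(state.keys()) == module_keys:
            return state
        children = state.values()
    elif isinstance(state, (list, tuple, set)):
        children = state
    else:
        return {}  # leaf value: nothing to search
    for child in children:
        found = find_sub_state_dict(child, module_keys)
        if found:
            return found
    return {}


checkpoint = {
    "epoch": 3,
    "model": {"weight": 42, "network": {"weight": [1.0, 1.0], "bias": [2.0]}},
}
print(find_sub_state_dict(checkpoint, {"weight", "bias"}))
# {'weight': [1.0, 1.0], 'bias': [2.0]}
```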

sonnix.utils.state_dict.load_state_dict_to_module

load_state_dict_to_module(
    state_dict: dict[Any, Any],
    module: Module,
    strict: bool = True,
) -> None

Load a state dict into a given module.

This function will automatically try to find the module state dict in the given state dict.

Parameters:

state_dict (dict[Any, Any], required): The state dict.
module (Module, required): The module. This function changes the weights of this module.
strict (bool, default True): Whether to strictly enforce that the keys in state_dict match the keys returned by this module's torch.nn.Module.state_dict function.

Example
>>> import torch
>>> from sonnix.utils.state_dict import load_state_dict_to_module
>>> state = {
...     "model": {
...         "weight": 42,
...         "network": {
...             "weight": torch.ones(5, 4),
...             "bias": 2 * torch.ones(5),
...         },
...     }
... }
>>> module = torch.nn.Linear(4, 5)
>>> load_state_dict_to_module(state, module)
>>> out = module(torch.ones(2, 4))
>>> out
tensor([[6., 6., 6., 6., 6.],
        [6., 6., 6., 6., 6.]], grad_fn=<AddmmBackward0>)
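
Loading a nested checkpoint can be sketched as "locate the module's entries, then call load_state_dict". This is a minimal sketch; the helper names are hypothetical, the search is restricted to nested dicts for brevity, and falling back to the full dict when no match is found (so strict loading reports the key mismatch) is an assumption.

```python
from typing import Any

import torch
from torch import nn


def _find(state: Any, keys: set[str]) -> dict[str, Any]:
    # Depth-first search over nested dicts for an exact key-set match.
    if isinstance(state, dict):
        if set(state) == keys:
            return state
        for value in state.values():
            found = _find(value, keys)
            if found:
                return found
    return {}


def load_nested_state_dict(state: dict[str, Any], module: nn.Module, strict: bool = True) -> None:
    """Hypothetical helper: locate the module's entries, then load them."""
    sub_state = _find(state, set(module.state_dict().keys()))
    module.load_state_dict(sub_state or state, strict=strict)


module = nn.Linear(4, 5)
checkpoint = {"model": {"network": {"weight": torch.ones(5, 4), "bias": 2 * torch.ones(5)}}}
load_nested_state_dict(checkpoint, module)
print(module(torch.ones(2, 4)))  # every entry is 4 * 1 + 2 = 6
```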