
karbonn.utils

Module

karbonn.utils.freeze_module

freeze_module(module: Module) -> None

Freeze the parameters of the given module.

Parameters:

Name Type Description Default
module Module

The module to freeze.

required

Example usage:

>>> import torch
>>> from karbonn.utils import freeze_module
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> for name, param in module.named_parameters():
...     print(name, param.requires_grad)
...
weight False
bias False
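Freezing comes down to switching off gradient tracking on each parameter. The sketch below shows this in plain PyTorch, assuming freeze_module and unfreeze_module work along these lines. The helpers freeze and unfreeze are hypothetical and are not karbonn code.

```python
import torch


def freeze(module: torch.nn.Module) -> None:
    """Hypothetical helper: disable gradients for every parameter, submodules included."""
    for param in module.parameters():
        param.requires_grad = False


def unfreeze(module: torch.nn.Module) -> None:
    """Hypothetical helper: re-enable gradients for every parameter, submodules included."""
    for param in module.parameters():
        param.requires_grad = True


module = torch.nn.Linear(4, 6)
freeze(module)
frozen = {name: p.requires_grad for name, p in module.named_parameters()}
unfreeze(module)
unfrozen = {name: p.requires_grad for name, p in module.named_parameters()}
print(frozen)  # {'weight': False, 'bias': False}
print(unfrozen)  # {'weight': True, 'bias': True}
```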

karbonn.utils.get_module_device

get_module_device(module: Module) -> device

Get the device used by this module.

This function assumes the module uses a single device. If the module uses several devices, use get_module_devices instead. It returns torch.device('cpu') if the module does not have parameters.

Parameters:

Name Type Description Default
module Module

The module.

required

Returns:

Type Description
device

The device used by the module.

Example usage:

>>> import torch
>>> from karbonn.utils import get_module_device
>>> get_module_device(torch.nn.Linear(4, 6))
device(type='cpu')

karbonn.utils.get_module_devices

get_module_devices(module: Module) -> tuple[device, ...]

Get the devices used in a module.

Parameters:

Name Type Description Default
module Module

The module.

required

Returns:

Type Description
tuple[device, ...]

The tuple of torch.device objects used by the module.

Example usage:

>>> import torch
>>> from karbonn.utils import get_module_devices
>>> get_module_devices(torch.nn.Linear(4, 6))
(device(type='cpu'),)
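The sketch below collects a module's devices from its parameters. module_devices and module_device are hypothetical helpers. Looking only at parameters, and not at buffers, is an assumption. It is consistent with the CPU fallback described for get_module_device.

```python
import torch


def module_devices(module: torch.nn.Module) -> tuple[torch.device, ...]:
    """Collect the distinct devices of the module's parameters (hypothetical helper)."""
    return tuple({param.device for param in module.parameters()})


def module_device(module: torch.nn.Module) -> torch.device:
    """Return the module's device, assuming it uses a single one (hypothetical helper)."""
    devices = module_devices(module)
    if not devices:
        # No parameters at all: fall back to the CPU.
        return torch.device("cpu")
    # With several devices the pick below is arbitrary, hence the single-device assumption.
    return devices[0]


print(module_devices(torch.nn.Linear(4, 6)))  # (device(type='cpu'),)
print(module_device(torch.nn.Identity()))  # cpu
```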

karbonn.utils.has_learnable_parameters

has_learnable_parameters(module: Module) -> bool

Indicate if the module has learnable parameters.

Parameters:

Name Type Description Default
module Module

The module to test.

required

Returns:

Type Description
bool

True if the module has at least one learnable parameter, False otherwise.

Example usage:

>>> import torch
>>> from karbonn.utils import has_learnable_parameters, freeze_module
>>> has_learnable_parameters(torch.nn.Linear(4, 6))
True
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> has_learnable_parameters(module)
False
>>> has_learnable_parameters(torch.nn.Identity())
False

karbonn.utils.has_parameters

has_parameters(module: Module) -> bool

Indicate if the module has parameters.

Parameters:

Name Type Description Default
module Module

The module to test.

required

Returns:

Type Description
bool

True if the module has at least one parameter, False otherwise.

Example usage:

>>> import torch
>>> from karbonn.utils import has_parameters
>>> has_parameters(torch.nn.Linear(4, 6))
True
>>> has_parameters(torch.nn.Identity())
False

karbonn.utils.is_module_on_device

is_module_on_device(module: Module, device: device) -> bool

Indicate if all the parameters of a module are on the specified device.

Parameters:

Name Type Description Default
module Module

The module.

required
device device

The device.

required

Returns:

Type Description
bool

True if all the parameters of the module are on the specified device, otherwise False.

Example usage:

>>> import torch
>>> from karbonn.utils import is_module_on_device
>>> is_module_on_device(torch.nn.Linear(4, 6), torch.device("cpu"))
True

karbonn.utils.module_mode

module_mode(module: Module) -> Generator[None]

Implement a context manager that restores the mode (train or eval) of every submodule individually.

Parameters:

Name Type Description Default
module Module

The module whose mode is restored.

required

Example usage:

>>> import torch
>>> from karbonn.utils import module_mode
>>> module = torch.nn.ModuleDict(
...     {"module1": torch.nn.Linear(4, 6), "module2": torch.nn.Linear(2, 4).eval()}
... )
>>> print(module["module1"].training, module["module2"].training)
True False
>>> with module_mode(module):
...     module.eval()
...     print(module["module1"].training, module["module2"].training)
...
ModuleDict(
  (module1): Linear(in_features=4, out_features=6, bias=True)
  (module2): Linear(in_features=2, out_features=4, bias=True)
)
False False
>>> print(module["module1"].training, module["module2"].training)
True False
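One way to restore each submodule individually is to record the training flag of every module yielded by torch.nn.Module.modules() and write it back on exit. The sketch below assumes this approach. restore_modes is a hypothetical name, and karbonn's implementation may differ.

```python
import contextlib
from collections.abc import Iterator

import torch


@contextlib.contextmanager
def restore_modes(module: torch.nn.Module) -> Iterator[None]:
    """Save the training flag of the module and every submodule; restore them on exit."""
    # modules() yields the module itself followed by all of its descendants.
    saved = [(submodule, submodule.training) for submodule in module.modules()]
    try:
        yield
    finally:
        # Write each flag back directly. Calling train() here would also
        # overwrite the flags of the children.
        for submodule, training in saved:
            submodule.training = training


model = torch.nn.ModuleDict(
    {"module1": torch.nn.Linear(4, 6), "module2": torch.nn.Linear(2, 4).eval()}
)
with restore_modes(model):
    model.eval()
    inside = (model["module1"].training, model["module2"].training)
after = (model["module1"].training, model["module2"].training)
print(inside, after)  # (False, False) (True, False)
```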

karbonn.utils.num_learnable_parameters

num_learnable_parameters(module: Module) -> int

Return the number of learnable parameters in the module.

Parameters:

Name Type Description Default
module Module

The module whose learnable parameters are counted.

required

Returns:

Type Description
int

The number of learnable parameters.

Example usage:

>>> import torch
>>> from karbonn.utils import freeze_module, num_learnable_parameters
>>> num_learnable_parameters(torch.nn.Linear(4, 6))
30
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> num_learnable_parameters(module)
0
>>> num_learnable_parameters(torch.nn.Identity())
0
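The counts above can be reproduced by summing torch.Tensor.numel() over the parameters, optionally keeping only those with requires_grad=True. The sketch below works under that assumption. count_parameters is a hypothetical helper, not part of karbonn.

```python
import torch


def count_parameters(module: torch.nn.Module, learnable_only: bool = False) -> int:
    """Hypothetical helper: total number of elements in the (learnable) parameters."""
    return sum(
        param.numel()
        for param in module.parameters()
        if param.requires_grad or not learnable_only
    )


module = torch.nn.Linear(4, 6)
total = count_parameters(module)  # weight 6 * 4 = 24 elements, plus 6 bias elements
module.requires_grad_(False)  # freeze every parameter
learnable = count_parameters(module, learnable_only=True)
print(total, learnable)  # 30 0
```

For torch.nn.Linear(4, 6), the weight has shape (6, 4) and the bias has shape (6,), which gives 24 + 6 = 30.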

karbonn.utils.num_parameters

num_parameters(module: Module) -> int

Return the number of parameters in the module.

Parameters:

Name Type Description Default
module Module

The module whose parameters are counted.

required

Returns:

Type Description
int

The number of parameters.

Example usage:

>>> import torch
>>> from karbonn.utils import num_parameters
>>> num_parameters(torch.nn.Linear(4, 6))
30
>>> num_parameters(torch.nn.Identity())
0

karbonn.utils.top_module_mode

top_module_mode(module: Module) -> Generator[None]

Implement a context manager that restores the mode (train or eval) of a given module.

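This context manager only restores the mode at the top level, i.e. the mode of the given module itself. Unlike module_mode, it does not restore the mode of each submodule individually.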

Parameters:

Name Type Description Default
module Module

The module whose mode is restored.

required

Example usage:

>>> import torch
>>> from karbonn.utils import top_module_mode
>>> module = torch.nn.Linear(4, 6)
>>> print(module.training)
True
>>> with top_module_mode(module):
...     module.eval()
...     print(module.training)
...
Linear(in_features=4, out_features=6, bias=True)
False
>>> print(module.training)
True

karbonn.utils.unfreeze_module

unfreeze_module(module: Module) -> None

Unfreeze the parameters of the given module.

Parameters:

Name Type Description Default
module Module

The module to unfreeze.

required

Example usage:

>>> import torch
>>> from karbonn.utils import freeze_module, unfreeze_module
>>> module = torch.nn.Linear(4, 6)
>>> freeze_module(module)
>>> unfreeze_module(module)
>>> for name, param in module.named_parameters():
...     print(name, param.requires_grad)
...
weight True
bias True

Factory

karbonn.utils.create_sequential

create_sequential(
    modules: Sequence[Module | dict],
) -> Sequential

Create a torch.nn.Sequential from a sequence of modules.

Parameters:

Name Type Description Default
modules Sequence[Module | dict]

The sequence of modules or their configuration.

required

Returns:

Type Description
Sequential

The instantiated torch.nn.Sequential object.

Example usage:

>>> from karbonn.utils.factory import create_sequential
>>> seq = create_sequential(
...     [
...         {"_target_": "torch.nn.Linear", "in_features": 4, "out_features": 6},
...         {"_target_": "torch.nn.ReLU"},
...         {"_target_": "torch.nn.Linear", "in_features": 6, "out_features": 6},
...     ]
... )
>>> seq
Sequential(
  (0): Linear(in_features=4, out_features=6, bias=True)
  (1): ReLU()
  (2): Linear(in_features=6, out_features=6, bias=True)
)

karbonn.utils.is_module_config

is_module_config(config: dict) -> bool

Indicate if the input configuration is a configuration for a torch.nn.Module.

This function only checks whether the value of the key _target_ is valid; it does not check the other values. If _target_ refers to a function, the function's return type hint is used to check the class.

Parameters:

Name Type Description Default
config dict

The configuration to check.

required

Returns:

Type Description
bool

True if the input configuration is a configuration for a torch.nn.Module object, otherwise False.

Example usage:

>>> from karbonn.utils import is_module_config
>>> is_module_config({"_target_": "torch.nn.Identity"})
True
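The check described above can be sketched with the standard library. The sketch resolves the dotted _target_ path. For a class, it tests the subclass relation. For a function, it tests the return type hint. All names below are hypothetical, and karbonn's own resolution machinery may behave differently. In karbonn the base class would be torch.nn.Module. The sketch uses dict so that it runs without PyTorch.

```python
import importlib
import inspect
import typing


def resolve_target(path: str) -> typing.Any:
    """Import the object named by a dotted path such as "collections.OrderedDict"."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def produces(target: typing.Any, base: type) -> bool:
    """Check whether calling ``target`` should return an instance of ``base``."""
    if inspect.isclass(target):
        return issubclass(target, base)
    if inspect.isfunction(target):
        # Functions are judged by their return type hint, if they have one.
        returned = typing.get_type_hints(target).get("return")
        return inspect.isclass(returned) and issubclass(returned, base)
    return False


def is_config_for(config: dict, base: type) -> bool:
    """Look only at ``_target_``; the other keys are not validated."""
    try:
        target = resolve_target(config["_target_"])
    except (KeyError, ImportError, AttributeError, ValueError):
        return False
    return produces(target, base)


def build_mapping() -> dict:
    return {}


print(is_config_for({"_target_": "collections.OrderedDict"}, dict))  # True
print(is_config_for({"_target_": "fractions.Fraction"}, dict))  # False
print(produces(build_mapping, dict))  # True
```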

karbonn.utils.setup_module

setup_module(module: Module | dict) -> Module

Set up a torch.nn.Module object.

Parameters:

Name Type Description Default
module Module | dict

The module or its configuration.

required

Returns:

Type Description
Module

The instantiated torch.nn.Module object.

Example usage:

>>> from karbonn.utils import setup_module
>>> linear = setup_module(
...     {"_target_": "torch.nn.Linear", "in_features": 4, "out_features": 6}
... )
>>> linear
Linear(in_features=4, out_features=6, bias=True)
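The sketch below illustrates the assumed behaviour: a configuration dict is turned into an object by importing _target_ and calling it with the remaining keys, and an object that is already built is returned unchanged. instantiate is a hypothetical helper, and standard-library classes stand in for torch.nn modules so the sketch runs without PyTorch. Applying the same idea to each element of a list would mirror what create_sequential does before it wraps the results in torch.nn.Sequential.

```python
import importlib
from typing import Any


def instantiate(obj: Any) -> Any:
    """Hypothetical sketch: build an object from a ``_target_`` config, else pass it through."""
    if not isinstance(obj, dict):
        return obj  # already an object: nothing to do
    kwargs = dict(obj)  # copy so the caller's config is left untouched
    module_name, _, attr = kwargs.pop("_target_").rpartition(".")
    factory = getattr(importlib.import_module(module_name), attr)
    return factory(**kwargs)


configs = [
    {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4},
    {"_target_": "datetime.timedelta", "minutes": 90},
]
objects = [instantiate(config) for config in configs]
print(objects)  # [Fraction(3, 4), datetime.timedelta(seconds=5400)]
```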

State dict

karbonn.utils.find_module_state_dict

find_module_state_dict(
    state_dict: dict | list | tuple | set, module_keys: set
) -> dict

Try to automatically find the part of the state dict that belongs to a module.

The caller should provide the set of the module's keys, for example set(module.state_dict().keys()). This function assumes that this set of keys exists at only one location in the state dict. If it exists at several locations, only the first match is returned.

Parameters:

Name Type Description Default
state_dict dict | list | tuple | set

The state dict. This function is called recursively on this input to find the queried state dict.

required
module_keys set

The set of module keys.

required

Returns:

Type Description
dict

The part of the state dict related to a module if it is found, otherwise an empty dict.

Example usage:

>>> import torch
>>> from karbonn.utils import find_module_state_dict
>>> state = {
...     "model": {
...         "weight": 42,
...         "network": {
...             "weight": torch.ones(5, 4),
...             "bias": 2 * torch.ones(5),
...         },
...     }
... }
>>> module = torch.nn.Linear(4, 5)
>>> state_dict = find_module_state_dict(state, module_keys=set(module.state_dict().keys()))
>>> state_dict
{'weight': tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]), 'bias': tensor([2., 2., 2., 2., 2.])}
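The search can be pictured as a depth-first traversal that returns the first dict whose keys match module_keys. The sketch below assumes an exact key-set match and dict insertion order for "first". find_state_dict is a hypothetical name, and nested lists stand in for tensors.

```python
from typing import Any


def find_state_dict(state: Any, module_keys: set) -> dict:
    """Depth-first search for the first dict whose keys are exactly module_keys (sketch)."""
    if isinstance(state, dict):
        if set(state) == module_keys:
            return state
        children = state.values()
    elif isinstance(state, (list, tuple, set)):
        children = state
    else:
        return {}  # leaves such as numbers or tensors cannot contain a state dict
    for child in children:
        found = find_state_dict(child, module_keys)
        if found:
            return found
    return {}


state = {
    "model": {
        "weight": 42,
        "network": {"weight": [[1.0] * 4] * 5, "bias": [2.0] * 5},
    }
}
print(sorted(find_state_dict(state, {"weight", "bias"})))  # ['bias', 'weight']
print(find_state_dict(state, {"running_mean"}))  # {}
```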

karbonn.utils.load_state_dict_to_module

load_state_dict_to_module(
    state_dict: dict, module: Module, strict: bool = True
) -> None

Load a state dict into a given module.

This function automatically tries to find the module state dict inside the given state dict.

Parameters:

Name Type Description Default
state_dict dict

The state dict.

required
module Module

The module. This function changes the weights of this module.

required
strict bool

Whether to strictly enforce that the keys in state_dict match the keys returned by the module's torch.nn.Module.state_dict method.

True

Example usage:

>>> import torch
>>> from karbonn.utils import load_state_dict_to_module
>>> state = {
...     "model": {
...         "weight": 42,
...         "network": {
...             "weight": torch.ones(5, 4),
...             "bias": 2 * torch.ones(5),
...         },
...     }
... }
>>> module = torch.nn.Linear(4, 5)
>>> load_state_dict_to_module(state, module)
>>> out = module(torch.ones(2, 4))
>>> out
tensor([[6., 6., 6., 6., 6.],
        [6., 6., 6., 6., 6.]], grad_fn=<AddmmBackward0>)

Loss

karbonn.utils.is_loss_decreasing

is_loss_decreasing(
    module: Module,
    criterion: Module | Callable[[Tensor, Tensor], Tensor],
    optimizer: Optimizer,
    feature: Tensor,
    target: Tensor,
    num_iterations: int = 1,
    random_seed: int = 10772155803920552556,
) -> bool

Check whether the loss decreases after num_iterations optimization steps.

Parameters:

Name Type Description Default
module Module

The module to test. The module must have a single input tensor and a single output tensor.

required
criterion Module | Callable[[Tensor, Tensor], Tensor]

The criterion to test.

required
optimizer Optimizer

The optimizer to update the weights of the model.

required
feature Tensor

The input of the module.

required
target Tensor

The target used to compute the loss.

required
num_iterations int

The number of optimization steps.

1
random_seed int

The random seed to make the function deterministic if the module contains randomness.

10772155803920552556

Returns:

Type Description
bool

True if the loss decreased after num_iterations optimization steps, otherwise False.

Example usage:

>>> import torch
>>> from torch import nn
>>> from torch.optim import SGD
>>> from karbonn.utils import is_loss_decreasing
>>> module = nn.Linear(4, 2)
>>> is_loss_decreasing(
...     module=module,
...     criterion=nn.MSELoss(),
...     optimizer=SGD(module.parameters(), lr=0.01),
...     feature=torch.rand(4, 4),
...     target=torch.rand(4, 2),
... )
True
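The check can be sketched as follows. Compute the loss, run num_iterations steps of the usual zero_grad, backward, step loop, then compare the final loss with the initial one. loss_decreases is a hypothetical name, and karbonn's implementation may differ in details such as where the seed is applied.

```python
import torch
from torch import nn


def loss_decreases(module, criterion, optimizer, feature, target, num_iterations=1, seed=0):
    """Compare the loss before and after a few optimizer steps (hypothetical sketch)."""
    torch.manual_seed(seed)  # makes randomness in the module (e.g. dropout) repeatable across runs
    with torch.no_grad():
        initial_loss = criterion(module(feature), target).item()
    for _ in range(num_iterations):
        optimizer.zero_grad()
        loss = criterion(module(feature), target)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        final_loss = criterion(module(feature), target).item()
    return final_loss < initial_loss


torch.manual_seed(0)
module = nn.Linear(4, 2)
decreased = loss_decreases(
    module,
    nn.MSELoss(),
    torch.optim.SGD(module.parameters(), lr=0.01),
    feature=torch.rand(4, 4),
    target=torch.rand(4, 2),
)
print(decreased)  # True
```

With a small learning rate, a single gradient step on a linear model with mean squared error should lower the loss. This makes the example a useful smoke test for the training loop.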

karbonn.utils.is_loss_decreasing_with_adam

is_loss_decreasing_with_adam(
    module: Module,
    criterion: Module | Callable[[Tensor, Tensor], Tensor],
    feature: Tensor,
    target: Tensor,
    lr: float = 0.0003,
    num_iterations: int = 1,
    random_seed: int = 10772155803920552556,
) -> bool

Check whether the loss decreases after num_iterations optimization steps.

The module is trained with the torch.optim.Adam optimizer.

Parameters:

Name Type Description Default
module Module

The module to test. The module must have a single input tensor and a single output tensor.

required
criterion Module | Callable[[Tensor, Tensor], Tensor]

The criterion to test.

required
feature Tensor

The input of the module.

required
target Tensor

The target used to compute the loss.

required
lr float

The learning rate.

0.0003
num_iterations int

The number of optimization steps.

1
random_seed int

The random seed to make the function deterministic if the module contains randomness.

10772155803920552556

Returns:

Type Description
bool

True if the loss decreased after num_iterations optimization steps, otherwise False.

Example usage:

>>> import torch
>>> from torch import nn
>>> from karbonn.utils import is_loss_decreasing_with_adam
>>> is_loss_decreasing_with_adam(
...     module=nn.Linear(4, 2),
...     criterion=nn.MSELoss(),
...     feature=torch.rand(4, 4),
...     target=torch.rand(4, 2),
...     lr=0.0003,
... )
True

karbonn.utils.is_loss_decreasing_with_sgd

is_loss_decreasing_with_sgd(
    module: Module,
    criterion: Module | Callable[[Tensor, Tensor], Tensor],
    feature: Tensor,
    target: Tensor,
    lr: float = 0.01,
    num_iterations: int = 1,
    random_seed: int = 10772155803920552556,
) -> bool

Check whether the loss decreases after num_iterations optimization steps.

The module is trained with the torch.optim.SGD optimizer.

Parameters:

Name Type Description Default
module Module

The module to test. The module must have a single input tensor and a single output tensor.

required
criterion Module | Callable[[Tensor, Tensor], Tensor]

The criterion to test.

required
feature Tensor

The input of the module.

required
target Tensor

The target used to compute the loss.

required
lr float

The learning rate.

0.01
num_iterations int

The number of optimization steps.

1
random_seed int

The random seed to make the function deterministic if the module contains randomness.

10772155803920552556

Returns:

Type Description
bool

True if the loss decreased after num_iterations optimization steps, otherwise False.

Example usage:

>>> import torch
>>> from torch import nn
>>> from karbonn.utils import is_loss_decreasing_with_sgd
>>> is_loss_decreasing_with_sgd(
...     module=nn.Linear(4, 2),
...     criterion=nn.MSELoss(),
...     feature=torch.rand(4, 4),
...     target=torch.rand(4, 2),
...     lr=0.01,
... )
True

Input/output sizes

karbonn.utils.size.find_in_features

find_in_features(module: Module) -> list[int]

Find the input feature sizes of a given module.

Parameters:

Name Type Description Default
module Module

The module.

required

Returns:

Type Description
list[int]

The input feature sizes.

Raises:

Type Description
SizeNotFoundError

if the input feature sizes could not be found.

Example usage:

>>> import torch
>>> from karbonn.utils.size import find_in_features
>>> module = torch.nn.Linear(4, 6)
>>> in_features = find_in_features(module)
>>> in_features
[4]
>>> module = torch.nn.Bilinear(in1_features=4, in2_features=6, out_features=8)
>>> in_features = find_in_features(module)
>>> in_features
[4, 6]

karbonn.utils.size.find_out_features

find_out_features(module: Module) -> list[int]

Find the output feature sizes of a given module.

Parameters:

Name Type Description Default
module Module

The module.

required

Returns:

Type Description
list[int]

The output feature sizes.

Raises:

Type Description
SizeNotFoundError

if the output feature sizes could not be found.

Example usage:

>>> import torch
>>> from karbonn.utils.size import find_out_features
>>> module = torch.nn.Linear(4, 6)
>>> out_features = find_out_features(module)
>>> out_features
[6]
>>> module = torch.nn.Bilinear(in1_features=4, in2_features=6, out_features=8)
>>> out_features = find_out_features(module)
>>> out_features
[8]

karbonn.utils.size.BaseSizeFinder

Bases: ABC, Generic[T]

Define the base class to find the input or output feature size of a module.

Example usage:

>>> import torch
>>> from karbonn.utils.size import AutoSizeFinder
>>> size_finder = AutoSizeFinder()
>>> module = torch.nn.Linear(4, 6)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.BaseSizeFinder.find_in_features abstractmethod

find_in_features(module: T) -> list[int]

Find the input feature sizes of a given module.

Parameters:

Name Type Description Default
module T

The module.

required

Returns:

Type Description
list[int]

The input feature sizes.

Raises:

Type Description
SizeNotFoundError

if the input feature size could not be found.

Example usage:

>>> import torch
>>> from karbonn.utils.size import AutoSizeFinder
>>> module = torch.nn.Linear(4, 6)
>>> size_finder = AutoSizeFinder()
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> module = torch.nn.Bilinear(in1_features=4, in2_features=6, out_features=8)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4, 6]

karbonn.utils.size.BaseSizeFinder.find_out_features abstractmethod

find_out_features(module: T) -> list[int]

Find the output feature sizes of a given module.

Parameters:

Name Type Description Default
module T

The module.

required

Returns:

Type Description
list[int]

The output feature sizes.

Raises:

Type Description
SizeNotFoundError

if the output feature size could not be found.

Example usage:

>>> import torch
>>> from karbonn.utils.size import AutoSizeFinder
>>> module = torch.nn.Linear(4, 6)
>>> size_finder = AutoSizeFinder()
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]
>>> module = torch.nn.Bilinear(in1_features=4, in2_features=6, out_features=8)
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[8]

karbonn.utils.size.AutoSizeFinder

Bases: BaseSizeFinder

Implement a size finder that automatically finds the size based on the module type.

Example usage:

>>> import torch
>>> from karbonn.utils.size import AutoSizeFinder
>>> size_finder = AutoSizeFinder()
>>> module = torch.nn.Linear(4, 6)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.AutoSizeFinder.add_size_finder classmethod

add_size_finder(
    module_type: type[Module],
    size_finder: BaseSizeFinder,
    exist_ok: bool = False,
) -> None

Add a size finder for a given module type.

Parameters:

Name Type Description Default
module_type type[Module]

The module type.

required
size_finder BaseSizeFinder

The size finder to use for the given module type.

required
exist_ok bool

If False, a RuntimeError is raised if a size finder is already registered for the module type. Set this parameter to True to overwrite the existing size finder for the module type.

False

Raises:

Type Description
RuntimeError

if a size finder is already registered for the module type and exist_ok=False.

Example usage:

>>> from torch import nn
>>> from karbonn.utils.size import AutoSizeFinder, LinearSizeFinder
>>> AutoSizeFinder.add_size_finder(nn.Linear, LinearSizeFinder(), exist_ok=True)

karbonn.utils.size.AutoSizeFinder.find_size_finder classmethod

find_size_finder(
    module_type: type[Module],
) -> BaseSizeFinder

Find the size finder associated with a module type.

Parameters:

Name Type Description Default
module_type type[Module]

The module type.

required

Returns:

Type Description
BaseSizeFinder

The size finder associated with the module type.

Example usage:

>>> from torch import nn
>>> from karbonn.utils.size import AutoSizeFinder
>>> AutoSizeFinder.find_size_finder(nn.Linear)
LinearSizeFinder()
>>> AutoSizeFinder.find_size_finder(nn.Bilinear)
BilinearSizeFinder()
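The default mapping returned by get_size_finders includes torch.nn.Module mapped to UnknownSizeFinder(). This suggests that lookups fall back on a registered base class. The sketch below shows one plausible way to do that by walking the class's method resolution order. It is an assumption, not karbonn's code. SizeFinderRegistry is hypothetical, and strings stand in for size finder objects.

```python
from typing import Any


class SizeFinderRegistry:
    """Hypothetical sketch of a type-keyed registry with inheritance fallback."""

    def __init__(self) -> None:
        self._finders: dict[type, Any] = {}

    def add(self, cls: type, finder: Any, exist_ok: bool = False) -> None:
        if cls in self._finders and not exist_ok:
            raise RuntimeError(f"A size finder is already registered for {cls}")
        self._finders[cls] = finder

    def has(self, cls: type) -> bool:
        # Exact-type check (an assumption about has_size_finder's semantics).
        return cls in self._finders

    def find(self, cls: type) -> Any:
        # Walk the MRO so a subclass uses the finder of its closest registered ancestor.
        for base in cls.__mro__:
            if base in self._finders:
                return self._finders[base]
        raise TypeError(f"No size finder registered for {cls}")


registry = SizeFinderRegistry()
registry.add(object, "fallback finder")  # plays the role of Module -> UnknownSizeFinder
registry.add(int, "int finder")
print(registry.find(bool))  # int finder (bool is a subclass of int)
print(registry.find(str))  # fallback finder
```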

karbonn.utils.size.AutoSizeFinder.has_size_finder classmethod

has_size_finder(module_type: type) -> bool

Indicate if a size finder is registered for the given module type.

Parameters:

Name Type Description Default
module_type type

The module type.

required

Returns:

Type Description
bool

True if a size finder is registered, otherwise False.

Example usage:

>>> from torch import nn
>>> from karbonn.utils.size import AutoSizeFinder
>>> AutoSizeFinder.has_size_finder(nn.Linear)
True
>>> AutoSizeFinder.has_size_finder(str)
False

karbonn.utils.size.register_size_finders

register_size_finders() -> None

Register size finders with AutoSizeFinder.

Example usage:

>>> from karbonn.utils.size import AutoSizeFinder, register_size_finders
>>> register_size_finders()
>>> size_finder = AutoSizeFinder()
>>> size_finder
AutoSizeFinder(
  ...
)

karbonn.utils.size.BatchNormSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for BatchNorm layers like torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, or torch.nn.SyncBatchNorm.

This module size finder assumes the module has a single input and output, and both the input size and the output size are given by the attribute num_features.

Example usage:

>>> import torch
>>> from karbonn.utils.size import BatchNormSizeFinder
>>> size_finder = BatchNormSizeFinder()
>>> module = torch.nn.BatchNorm1d(num_features=6)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[6]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.BilinearSizeFinder

Bases: BaseSizeFinder

Implement a size finder for torch.nn.Bilinear layer or similar layers.

This module size finder assumes the module has two inputs and one output. The input sizes are given by the attributes in1_features and in2_features, and the output size is given by the attribute out_features.

Example usage:

>>> import torch
>>> from karbonn.utils.size import BilinearSizeFinder
>>> size_finder = BilinearSizeFinder()
>>> module = torch.nn.Bilinear(in1_features=4, in2_features=2, out_features=6)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4, 2]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.ConvolutionSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for convolution layers like torch.nn.ConvNd and torch.nn.ConvTransposeNd.

This module size finder assumes the module has a single input and output, and the input size is given by the attribute in_channels and the output size is given by the attribute out_channels.

Example usage:

>>> import torch
>>> from karbonn.utils.size import ConvolutionSizeFinder
>>> size_finder = ConvolutionSizeFinder()
>>> module = torch.nn.Conv2d(in_channels=4, out_channels=6, kernel_size=1)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.EmbeddingSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for embedding layers like torch.nn.Embedding.

This module size finder assumes the module has a single input and output, and the input size is always 1, and the output size is given by the attribute embedding_dim.

Example usage:

>>> import torch
>>> from karbonn.utils.size import EmbeddingSizeFinder
>>> size_finder = EmbeddingSizeFinder()
>>> module = torch.nn.Embedding(num_embeddings=5, embedding_dim=6)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[1]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.GroupNormSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for Group Normalization layers like torch.nn.GroupNorm.

This module size finder assumes the module has a single input and output, and both the input size and the output size are given by the attribute num_channels.

Example usage:

>>> import torch
>>> from karbonn.utils.size import GroupNormSizeFinder
>>> size_finder = GroupNormSizeFinder()
>>> module = torch.nn.GroupNorm(num_groups=2, num_channels=8)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[8]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[8]

karbonn.utils.size.LinearSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for torch.nn.Linear layer or similar layers.

This module size finder assumes the module has a single input and output, and the input size is given by the attribute in_features and the output size is given by the attribute out_features.

Example usage:

>>> import torch
>>> from karbonn.utils.size import LinearSizeFinder
>>> size_finder = LinearSizeFinder()
>>> module = torch.nn.Linear(4, 6)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.ModuleListSizeFinder

Bases: BaseSizeFinder[ModuleList]

Implement a size finder for torch.nn.ModuleList layer or similar layers.

Example usage:

>>> import torch
>>> from torch import nn
>>> from karbonn.utils.size import ModuleListSizeFinder
>>> size_finder = ModuleListSizeFinder()
>>> module = nn.ModuleList(
...     [nn.Linear(4, 6), nn.ReLU(), nn.LSTM(input_size=4, hidden_size=6)]
... )
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.MultiheadAttentionSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for torch.nn.MultiheadAttention layer or similar layers.

This module size finder assumes the module has three inputs (query, key, and value) and a single output. The input and output sizes are given by the attribute embed_dim.

Example usage:

>>> import torch
>>> from karbonn.utils.size import MultiheadAttentionSizeFinder
>>> size_finder = MultiheadAttentionSizeFinder()
>>> module = torch.nn.MultiheadAttention(embed_dim=4, num_heads=2)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4, 4, 4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[4]

karbonn.utils.size.RecurrentSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for recurrent layers like torch.nn.RNN, torch.nn.GRU, and torch.nn.LSTM.

This module size finder assumes the module has a single input and output, and the input size is given by the attribute input_size and the output size is given by the attribute hidden_size.

Example usage:

>>> import torch
>>> from karbonn.utils.size import RecurrentSizeFinder
>>> size_finder = RecurrentSizeFinder()
>>> module = torch.nn.RNN(input_size=4, hidden_size=6)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[6]

karbonn.utils.size.SequentialSizeFinder

Bases: BaseSizeFinder[Sequential]

Implement a size finder for torch.nn.Sequential layer.

This module size finder iterates over the child modules until it finds one whose size can be computed: the input sizes come from the first such child, and the output sizes come from the last such child.

Example usage:

>>> import torch
>>> from karbonn.utils.size import SequentialSizeFinder
>>> size_finder = SequentialSizeFinder()
>>> module = torch.nn.Sequential(
...     torch.nn.Linear(4, 6), torch.nn.ReLU(), torch.nn.Linear(6, 8)
... )
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[8]
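The example output suggests the following search. Scan forward for the input sizes and backward for the output sizes, and skip children such as ReLU whose sizes are unknown. The sketch below shows that search in plain Python. SizeUnknownError and first_known_size are hypothetical names, and dicts stand in for the Linear and ReLU layers.

```python
from collections.abc import Callable, Sequence
from typing import Any


class SizeUnknownError(Exception):
    """Hypothetical stand-in for the library's size-not-found exception."""


def first_known_size(
    children: Sequence[Any], find: Callable[[Any], list[int]], reverse: bool = False
) -> list[int]:
    """Return the sizes of the first child (last child if reverse=True) whose size is known."""
    ordered = reversed(children) if reverse else children
    for child in ordered:
        try:
            return find(child)
        except SizeUnknownError:
            continue  # e.g. an activation such as ReLU has no feature size
    raise SizeUnknownError("no child with a known size")


def find_in(layer: dict) -> list[int]:
    if "in" not in layer:
        raise SizeUnknownError(layer)
    return [layer["in"]]


def find_out(layer: dict) -> list[int]:
    if "out" not in layer:
        raise SizeUnknownError(layer)
    return [layer["out"]]


# Stand-ins for Linear(4, 6), ReLU(), Linear(6, 8).
layers = [{"in": 4, "out": 6}, {}, {"in": 6, "out": 8}]
print(first_known_size(layers, find_in))  # [4]
print(first_known_size(layers, find_out, reverse=True))  # [8]
```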

karbonn.utils.size.TransformerLayerSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for layers like torch.nn.TransformerEncoderLayer or torch.nn.TransformerDecoderLayer.

This module size finder assumes the module has an attribute self_attn which is used to find the input and output feature sizes.

Example usage:

>>> import torch
>>> from karbonn.utils.size import TransformerLayerSizeFinder
>>> size_finder = TransformerLayerSizeFinder()
>>> module = torch.nn.TransformerEncoderLayer(d_model=4, nhead=1)
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[4]

karbonn.utils.size.TransformerSizeFinder

Bases: BaseSizeFinder[Module]

Implement a size finder for layers like torch.nn.TransformerEncoder or torch.nn.TransformerDecoder.

This module size finder assumes the module has an attribute layers (the stack of transformer layers) which is used to find the input and output feature sizes.

Example usage:

>>> import torch
>>> from karbonn.utils.size import TransformerSizeFinder
>>> size_finder = TransformerSizeFinder()
>>> module = torch.nn.TransformerEncoder(
...     torch.nn.TransformerEncoderLayer(d_model=4, nhead=1),
...     num_layers=1,
... )
>>> in_features = size_finder.find_in_features(module)
>>> in_features
[4]
>>> out_features = size_finder.find_out_features(module)
>>> out_features
[4]

karbonn.utils.size.UnknownSizeFinder

Bases: BaseSizeFinder

Implement a size finder for the modules where the input and output feature sizes are unknown.

Example usage:

>>> import torch
>>> from karbonn.utils.size import UnknownSizeFinder
>>> size_finder = UnknownSizeFinder()
>>> module = torch.nn.ReLU()
>>> in_features = size_finder.find_in_features(module)  # doctest: +SKIP
>>> out_features = size_finder.find_out_features(module)  # doctest: +SKIP

karbonn.utils.size.get_size_finders

get_size_finders() -> dict[type[Module], BaseSizeFinder]

Return the default mappings between the module types and their size finders.

Returns:

Type Description
dict[type[Module], BaseSizeFinder]

The default mappings between the module types and their size finders.

Example usage:

>>> from karbonn.utils.size import get_size_finders
>>> get_size_finders()
{<class 'torch.nn.modules.module.Module'>: UnknownSizeFinder(), ...}

karbonn.utils.size.get_torch_size_finders

get_torch_size_finders() -> (
    dict[type[Module], BaseSizeFinder]
)

Return the default mappings between the PyTorch module types and their size finders.

Returns:

Type Description
dict[type[Module], BaseSizeFinder]

The default mappings between the module types and their size finders.

Example usage:

>>> from karbonn.utils.size import get_torch_size_finders
>>> get_torch_size_finders()
{<class 'torch.nn.modules.module.Module'>: UnknownSizeFinder(), ...}

karbonn.utils.size.get_karbonn_size_finders

get_karbonn_size_finders() -> (
    dict[type[Module], BaseSizeFinder]
)

Return the default mappings between the karbonn module types and their size finders.

Returns:

Type Description
dict[type[Module], BaseSizeFinder]

The default mappings between the module types and their size finders.

Example usage:

>>> from karbonn.utils.size import get_karbonn_size_finders
>>> get_karbonn_size_finders()
{...}