
Overview

batchtensor is a lightweight library built on top of PyTorch for manipulating nested data structures of PyTorch tensors. It provides functions for tensors whose first dimension is the batch dimension, as well as for tensors representing batches of sequences, where the first dimension is the batch dimension and the second dimension is the sequence dimension.
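To make this dimension convention concrete, here is a small sketch using plain PyTorch indexing on a tensor of shape (batch, sequence, feature):

```python
import torch

# A batch of 4 sequences, each with 6 time steps and 2 features:
# dim 0 is the batch dimension, dim 1 is the sequence dimension.
x = torch.randn(4, 6, 2)

print(x[:2].shape)     # first 2 items in the batch -> torch.Size([2, 6, 2])
print(x[:, :3].shape)  # first 3 steps of every sequence -> torch.Size([4, 3, 2])
```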

Key Features

  • Nested Structure Support: Work with dictionaries, lists, and tuples containing tensors
  • Batch Operations: Efficiently process batches of data along the batch dimension
  • Sequence Operations: Handle sequential/temporal data along the sequence dimension
  • Consistent API: Unified interface for both single tensors and nested structures
  • Type Safety: Fully typed with comprehensive type hints
  • Well Documented: Extensive documentation with examples for all functions
  • Lightweight: Minimal dependencies (PyTorch and coola)
  • Performance: Leverages PyTorch's optimized operations
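The nested-structure support above can be pictured as a recursive map over containers. The following is a minimal sketch of that idea, not batchtensor's actual implementation, and it assumes only dicts, lists, and tuples of tensors:

```python
import torch

def recursive_apply(data, fn):
    """Apply ``fn`` to every tensor found in a nested dict/list/tuple.

    A simplified sketch of the idea; batchtensor's internals may differ.
    """
    if isinstance(data, torch.Tensor):
        return fn(data)
    if isinstance(data, dict):
        return {key: recursive_apply(value, fn) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(recursive_apply(item, fn) for item in data)
    return data  # leave non-tensor leaves untouched

batch = {"a": torch.arange(10).view(5, 2), "b": [torch.ones(5), torch.zeros(5)]}
sliced = recursive_apply(batch, lambda t: t[:2])  # keep the first 2 batch items
print(sliced["a"].shape)  # torch.Size([2, 2])
```

Applying the same callable to every leaf tensor is what lets a single function handle both a bare tensor and an arbitrarily nested structure with one consistent interface.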

Motivation

Imagine you have a batch represented by a dictionary of three tensors, and you want to take the first two items. batchtensor provides the function slice_along_batch, which slices all the tensors at once:

>>> import torch
>>> from batchtensor.nested import slice_along_batch
>>> batch = {
...     "a": torch.tensor([[2, 6], [0, 3], [4, 9], [8, 1], [5, 7]]),
...     "b": torch.tensor([4, 3, 2, 1, 0]),
...     "c": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
... }
>>> slice_along_batch(batch, stop=2)
{'a': tensor([[2, 6], [0, 3]]), 'b': tensor([4, 3]), 'c': tensor([1., 2.])}

Similarly, it is possible to split a batch into multiple smaller batches with the function split_along_batch:

>>> import torch
>>> from batchtensor.nested import split_along_batch
>>> batch = {
...     "a": torch.tensor([[2, 6], [0, 3], [4, 9], [8, 1], [5, 7]]),
...     "b": torch.tensor([4, 3, 2, 1, 0]),
...     "c": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
... }
>>> split_along_batch(batch, split_size_or_sections=2)
({'a': tensor([[2, 6], [0, 3]]), 'b': tensor([4, 3]), 'c': tensor([1., 2.])},
 {'a': tensor([[4, 9], [8, 1]]), 'b': tensor([2, 1]), 'c': tensor([3., 4.])},
 {'a': tensor([[5, 7]]), 'b': tensor([0]), 'c': tensor([5.])})
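Conceptually, this is equivalent to applying torch.split to each tensor along dim 0 and regrouping the chunks into one dictionary per chunk. The sketch below shows that equivalence with plain PyTorch; the library's actual implementation may differ:

```python
import torch

batch = {
    "a": torch.tensor([[2, 6], [0, 3], [4, 9], [8, 1], [5, 7]]),
    "b": torch.tensor([4, 3, 2, 1, 0]),
}

# Split each tensor along the batch dimension (dim 0) into chunks of size 2...
chunks = {key: value.split(2, dim=0) for key, value in batch.items()}

# ...then regroup the i-th chunk of every tensor into one dict per chunk.
num_chunks = len(next(iter(chunks.values())))
result = tuple(
    {key: parts[i] for key, parts in chunks.items()} for i in range(num_chunks)
)
print(result[-1])  # {'a': tensor([[5, 7]]), 'b': tensor([0])}
```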

Please check the user guide and the API reference for the full list of implemented functions and detailed examples.

API stability

⚠ While batchtensor is in its development stage, no API is guaranteed to be stable from one release to the next. In fact, the API is likely to change multiple times before a stable 1.0.0 release. In practice, this means that upgrading batchtensor to a new version may break code that relied on the previous version.

License

batchtensor is licensed under the BSD 3-Clause "New" or "Revised" license, which is available in the LICENSE file.