Tensor Operations¶
The batchtensor.tensor module provides functions to manipulate individual PyTorch tensors along
batch and sequence dimensions. These functions are the foundation for the nested operations and can
be used directly when working with single tensors.
Overview¶
The tensor module operates on individual tensors with two primary dimensional conventions:
- Batch dimension: The first dimension (dimension 0) represents the batch
- Sequence dimension: The second dimension (dimension 1) represents the sequence/time steps
All functions in this module follow these conventions. The shape assumptions are:
- Batch operations: (batch_size, *), where * means any additional dimensions
- Sequence operations: (batch_size, seq_len, *), where * means any additional dimensions
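To make the conventions concrete, here is a small sketch in plain PyTorch (the shapes are illustrative, not required by the library):

```python
import torch

# A batch of 4 sequences, each 10 steps long, with 3 features per step:
# this shape follows the (batch_size, seq_len, *) convention.
batch = torch.randn(4, 10, 3)

batch_size = batch.shape[0]  # dimension 0 -> batch
seq_len = batch.shape[1]     # dimension 1 -> sequence/time steps

print(batch_size, seq_len)  # 4 10
```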
Slicing Operations¶
Slicing Along Batch Dimension¶
Extract a subset of the batch:
>>> import torch
>>> from batchtensor.tensor import slice_along_batch
>>> tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> # Take first 2 items
>>> slice_along_batch(tensor, stop=2)
tensor([[1, 2],
[3, 4]])
>>> # Take items 1 to 3
>>> slice_along_batch(tensor, start=1, stop=3)
tensor([[3, 4],
[5, 6]])
>>> # Take every other item
>>> slice_along_batch(tensor, step=2)
tensor([[1, 2],
[5, 6]])
Slicing Along Sequence Dimension¶
For sequential data:
>>> import torch
>>> from batchtensor.tensor import slice_along_seq
>>> tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> # Take first 2 timesteps
>>> slice_along_seq(tensor, stop=2)
tensor([[1, 2],
[5, 6]])
>>> # Take last 2 timesteps
>>> slice_along_seq(tensor, start=2)
tensor([[3, 4],
[7, 8]])
Selecting Single Items¶
Select a specific index:
>>> import torch
>>> from batchtensor.tensor import select_along_batch, select_along_seq
>>> tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> # Select second batch item (removes batch dimension)
>>> select_along_batch(tensor, index=1)
tensor([4, 5, 6])
>>> # Select first sequence position (removes sequence dimension)
>>> select_along_seq(tensor, index=0)
tensor([1, 4, 7])
Chunking¶
Split tensors into a given number of chunks (the last chunk may be smaller):
>>> import torch
>>> from batchtensor.tensor import chunk_along_batch, chunk_along_seq
>>> tensor = torch.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
>>> chunks = chunk_along_batch(tensor, chunks=3)
>>> len(chunks)
3
>>> chunks[0]
tensor([[0, 1],
[2, 3]])
>>> chunks[1]
tensor([[4, 5],
[6, 7]])
>>> chunks[2]
tensor([[8, 9]])
Splitting¶
Split tensors by specific sizes:
>>> import torch
>>> from batchtensor.tensor import split_along_batch
>>> tensor = torch.tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
>>> # Split into chunks of size 2
>>> splits = split_along_batch(tensor, split_size_or_sections=2)
>>> len(splits)
3
>>> splits[0]
tensor([[0, 1],
[2, 3]])
>>> splits[1]
tensor([[4, 5],
[6, 7]])
>>> splits[2]
tensor([[8, 9]])
>>> # Split with specific sizes
>>> splits = split_along_batch(tensor, split_size_or_sections=[2, 1, 2])
>>> len(splits)
3
>>> splits[0]
tensor([[0, 1],
[2, 3]])
>>> splits[1]
tensor([[4, 5]])
>>> splits[2]
tensor([[6, 7],
[8, 9]])
Indexing Operations¶
Index Selection¶
Select specific indices from a tensor:
>>> import torch
>>> from batchtensor.tensor import index_select_along_batch
>>> tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
>>> indices = torch.tensor([2, 0, 1])
>>> index_select_along_batch(tensor, indices)
tensor([[5, 6],
[1, 2],
[3, 4]])
For sequences:
>>> import torch
>>> from batchtensor.tensor import index_select_along_seq
>>> tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> indices = torch.tensor([3, 0, 2])
>>> index_select_along_seq(tensor, indices)
tensor([[4, 1, 3],
[8, 5, 7]])
Joining Operations¶
Concatenation¶
Combine multiple tensors:
>>> import torch
>>> from batchtensor.tensor import cat_along_batch
>>> tensor1 = torch.tensor([[1, 2]])
>>> tensor2 = torch.tensor([[3, 4]])
>>> tensor3 = torch.tensor([[5, 6]])
>>> cat_along_batch([tensor1, tensor2, tensor3])
tensor([[1, 2],
[3, 4],
[5, 6]])
Concatenate along sequence dimension:
>>> import torch
>>> from batchtensor.tensor import cat_along_seq
>>> tensor1 = torch.tensor([[1], [2]])
>>> tensor2 = torch.tensor([[3], [4]])
>>> cat_along_seq([tensor1, tensor2])
tensor([[1, 3],
[2, 4]])
Repetition¶
Repeat sequences:
>>> import torch
>>> from batchtensor.tensor import repeat_along_seq
>>> tensor = torch.tensor([[1, 2], [3, 4]])
>>> repeat_along_seq(tensor, repeats=3)
tensor([[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4]])
Reduction Operations¶
Sum and Mean¶
Compute the sum and mean along the batch dimension:
>>> import torch
>>> from batchtensor.tensor import sum_along_batch, mean_along_batch
>>> tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sum_along_batch(tensor)
tensor([ 9., 12.])
>>> mean_along_batch(tensor)
tensor([3., 4.])
>>> # Keep dimensions
>>> sum_along_batch(tensor, keepdim=True)
tensor([[ 9., 12.]])
Along the sequence dimension:
>>> import torch
>>> from batchtensor.tensor import sum_along_seq, mean_along_seq
>>> tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> sum_along_seq(tensor)
tensor([ 6., 15.])
>>> mean_along_seq(tensor)
tensor([2., 5.])
Product¶
Compute the product along the batch dimension:
>>> import torch
>>> from batchtensor.tensor import prod_along_batch
>>> tensor = torch.tensor([[2, 3], [4, 5]])
>>> prod_along_batch(tensor)
tensor([ 8, 15])
Min and Max¶
Find minimum and maximum values:
>>> import torch
>>> from batchtensor.tensor import (
... amax_along_batch,
... amin_along_batch,
... max_along_batch,
... min_along_batch,
... )
>>> tensor = torch.tensor([[1, 5], [3, 2], [4, 6]])
>>> amax_along_batch(tensor)
tensor([4, 6])
>>> amin_along_batch(tensor)
tensor([1, 2])
>>> # max and min return both values and indices
>>> max_along_batch(tensor)
torch.return_types.max(
values=tensor([4, 6]),
indices=tensor([2, 2]))
>>> min_along_batch(tensor)
torch.return_types.min(
values=tensor([1, 2]),
indices=tensor([0, 1]))
ArgMax and ArgMin¶
Find indices of extrema:
>>> import torch
>>> from batchtensor.tensor import argmax_along_batch, argmin_along_batch
>>> tensor = torch.tensor([[1, 5], [3, 2], [4, 6]])
>>> argmax_along_batch(tensor)
tensor([2, 2])
>>> argmin_along_batch(tensor)
tensor([0, 1])
Median¶
Compute median values:
>>> import torch
>>> from batchtensor.tensor import median_along_batch
>>> tensor = torch.tensor([[1, 5], [3, 2], [4, 6]])
>>> median_along_batch(tensor)
torch.return_types.median(
values=tensor([3, 5]),
indices=tensor([1, 0]))
Comparison and Sorting¶
Sorting¶
Sort tensors:
>>> import torch
>>> from batchtensor.tensor import sort_along_batch, argsort_along_batch
>>> tensor = torch.tensor([[3, 1], [1, 4], [2, 2]])
>>> sort_along_batch(tensor)
torch.return_types.sort(
values=tensor([[1, 1],
[2, 2],
[3, 4]]),
indices=tensor([[1, 0],
[2, 2],
[0, 1]]))
>>> # Get indices only
>>> argsort_along_batch(tensor)
tensor([[1, 0],
[2, 2],
[0, 1]])
>>> # Sort in descending order
>>> sort_along_batch(tensor, descending=True)
torch.return_types.sort(
values=tensor([[3, 4],
[2, 2],
[1, 1]]),
indices=tensor([[0, 1],
[2, 2],
[1, 0]]))
Mathematical Operations¶
Cumulative Operations¶
Compute cumulative sums and products:
>>> import torch
>>> from batchtensor.tensor import cumsum_along_batch, cumprod_along_batch
>>> tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
>>> cumsum_along_batch(tensor)
tensor([[ 1, 2],
[ 4, 6],
[ 9, 12]])
>>> cumprod_along_batch(tensor)
tensor([[ 1, 2],
[ 3, 8],
[15, 48]])
Along the sequence dimension:
>>> import torch
>>> from batchtensor.tensor import cumsum_along_seq
>>> tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> cumsum_along_seq(tensor)
tensor([[ 1, 3, 6],
[ 4, 9, 15]])
Permutation Operations¶
Shuffling¶
Randomly permute batch items:
>>> import torch
>>> from batchtensor.tensor import shuffle_along_batch
>>> tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
>>> # Results will be random
>>> shuffled = shuffle_along_batch(tensor)
For reproducible shuffling, use a generator:
>>> import torch
>>> from batchtensor.tensor import shuffle_along_batch
>>> from batchtensor.utils.seed import get_torch_generator
>>> tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
>>> generator = get_torch_generator(42)
>>> shuffled = shuffle_along_batch(tensor, generator=generator)
Permuting with Specific Order¶
Apply a specific permutation:
>>> import torch
>>> from batchtensor.tensor import permute_along_batch
>>> tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
>>> permutation = torch.tensor([2, 0, 1])
>>> permute_along_batch(tensor, permutation)
tensor([[5, 6],
[1, 2],
[3, 4]])
Best Practices¶
- Dimension Awareness: Always ensure your tensors follow the expected shape conventions (batch dimension first, sequence dimension second)
- Memory Efficiency: Many operations, such as slice and select, return views when possible rather than copies
- Batch Processing: These functions are optimized for batch processing, making them efficient for handling multiple samples simultaneously
- Consistency: Use these functions instead of direct indexing to maintain code clarity and reduce errors
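The view behavior can be checked with plain PyTorch: a basic slice shares storage with its source, so writes through the slice are visible in the original tensor. The slicing helpers above are assumed to reduce to such basic slices, though whether a view is returned in any particular case is an implementation detail.

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])

# A basic slice along dimension 0 is a view, not a copy:
view = tensor[:2]
print(view.data_ptr() == tensor.data_ptr())  # True: same underlying storage

# Writes through the view are visible in the original tensor.
view[0, 0] = 99
print(tensor[0, 0].item())  # 99
```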
Performance Considerations¶
- Operations leverage PyTorch's optimized tensor operations
- Views are used when possible to avoid data copying
- Batch operations are more efficient than processing items one at a time
- GPU acceleration is available through PyTorch's device management
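Because the helpers delegate to PyTorch tensor operations (an assumption consistent with the outputs shown above), they run on whatever device the input tensor lives on. A common device-selection pattern:

```python
import torch

# Pick a GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=device)

# Plain PyTorch reductions (and, by extension, the *_along_batch helpers,
# which are assumed to delegate to them) execute on the tensor's device.
result = tensor.sum(dim=0)
print(result.device == tensor.device)  # True
```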
See Also¶
- Nested Operations - Operations for nested data structures
- API Reference - Complete function reference
- Constants - Dimension constants used throughout the library