BatchDict

redcat.BatchDict

Bases: BaseBatch[dict[Hashable, TBaseBatch]]

Implements a batch object to represent a dictionary of batches.

Parameters:

Name Type Description Default
data dict[Hashable, TBaseBatch]

Specifies the dictionary of batches.

required

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch
BatchDict(
  (key1): tensor([[0, 1, 2, 3, 4],
            [5, 6, 7, 8, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.cat_along_seq

cat_along_seq(
    batches: TBaseBatch | Sequence[TBaseBatch],
) -> Self

Concatenates the data of the batches to the current batch along the sequence dimension and creates a new batch.

Note that only the values with a sequence dimension are concatenated; the other values are left unchanged.

Parameters:

Name Type Description Default
batches TBaseBatch | Sequence[TBaseBatch]

Specifies the batches to concatenate along the sequence dimension.

required

Returns:

Type Description
Self

A batch with the concatenated data along the sequence dimension.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.cat_along_seq(
...     BatchDict({"key1": BatchedTensorSeq(torch.tensor([[10, 11, 12], [20, 21, 22]]))})
... )
BatchDict(
  (key1): tensor([[ 0,  1,  2,  3,  4, 10, 11, 12],
            [ 5,  6,  7,  8,  9, 20, 21, 22]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.cat_along_seq_

cat_along_seq_(
    batches: TBaseBatch | Sequence[TBaseBatch],
) -> None

Concatenates the data of the batches to the current batch along the sequence dimension, modifying the current batch in-place.

Note that only the values with a sequence dimension are concatenated; the other values are left unchanged.

Parameters:

Name Type Description Default
batches TBaseBatch | Sequence[TBaseBatch]

Specifies the batches to concatenate along the sequence dimension.

required

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.cat_along_seq_(
...     BatchDict({"key1": BatchedTensorSeq(torch.tensor([[10, 11, 12], [20, 21, 22]]))})
... )
>>> batch
BatchDict(
  (key1): tensor([[ 0,  1,  2,  3,  4, 10, 11, 12],
            [ 5,  6,  7,  8,  9, 20, 21, 22]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.index_select_along_seq

index_select_along_seq(
    index: Tensor | Sequence[int],
) -> Self

Slices the batch along the sequence dimension at the given indices.

Parameters:

Name Type Description Default
index Tensor | Sequence[int]

Specifies the indices to select.

required

Returns:

Type Description
Self

A new batch sliced along the sequence dimension at the given indices.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.index_select_along_seq([2, 4])
BatchDict(
  (key1): tensor([[2, 4], [7, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)
>>> batch.index_select_along_seq(torch.tensor([2, 4]))
BatchDict(
  (key1): tensor([[2, 4], [7, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)
>>> batch.index_select_along_seq(torch.tensor([[2, 4], [4, 3]]))
BatchDict(
  (key1): tensor([[2, 4], [9, 8]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.permute_along_seq

permute_along_seq(
    permutation: Sequence[int] | Tensor,
) -> Self

Permutes the data along the sequence dimension.

The same permutation is applied on all the sequences. This method should be called only if all the sequences have the same length.

This method only permutes the values that implement permute_along_seq.

Parameters:

Name Type Description Default
permutation Sequence[int] | Tensor

Specifies the permutation to use on the data. The dimension of the permutation input should be compatible with the shape of the data.

required

Returns:

Type Description
Self

A new batch with permuted data.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.permute_along_seq([2, 1, 3, 0, 4])
BatchDict(
  (key1): tensor([[2, 1, 3, 0, 4],
                 [7, 6, 8, 5, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.permute_along_seq_

permute_along_seq_(
    permutation: Sequence[int] | Tensor,
) -> None

Permutes the data along the sequence dimension in-place.

The same permutation is applied on all the sequences. This method should be called only if all the sequences have the same length.

This method only permutes the values that implement permute_along_seq.

Parameters:

Name Type Description Default
permutation Sequence[int] | Tensor

Specifies the permutation to use on the data. The dimension of the permutation input should be compatible with the shape of the data.

required

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.permute_along_seq_([2, 1, 3, 0, 4])
>>> batch
BatchDict(
  (key1): tensor([[2, 1, 3, 0, 4],
                 [7, 6, 8, 5, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.repeat_along_seq

repeat_along_seq(repeats: int) -> Self

Repeats the batch along the sequence dimension.

Parameters:

Name Type Description Default
repeats int

Specifies the number of times to repeat the batch along the sequence dimension.

required

Returns:

Type Description
Self

A repeated version of the input batch.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.repeat_along_seq(2)
BatchDict(
  (key1): tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
            [5, 6, 7, 8, 9, 5, 6, 7, 8, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.shuffle_along_seq

shuffle_along_seq(
    generator: Generator | None = None,
) -> Self

Shuffles the data along the sequence dimension.

This method should be called only if all the sequences have the same length.

Parameters:

Name Type Description Default
generator Generator | None

Specifies a pseudo-random number generator.

None

Returns:

Type Description
Self

A new batch with shuffled data.

Raises:

Type Description
RuntimeError

If the sequences in the batch have different lengths.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.shuffle_along_seq()
BatchDict(
  (key1): tensor([[...]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.shuffle_along_seq_

shuffle_along_seq_(
    generator: Generator | None = None,
) -> None

Shuffles the data along the sequence dimension in-place.

This method should be called only if all the sequences have the same length.

Parameters:

Name Type Description Default
generator Generator | None

Specifies a pseudo-random number generator.

None

Raises:

Type Description
RuntimeError

If the sequences in the batch have different lengths.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.shuffle_along_seq_()
>>> batch
BatchDict(
  (key1): tensor([[...]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.slice_along_seq

slice_along_seq(
    start: int = 0, stop: int | None = None, step: int = 1
) -> Self

Slices the batch along the sequence dimension.

Parameters:

Name Type Description Default
start int

Specifies the index where the slicing starts.

0
stop int | None

Specifies the index where the slicing stops. None means the end of the sequence.

None
step int

Specifies the increment between each index for slicing.

1

Returns:

Type Description
Self

A slice of the current batch.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.slice_along_seq(start=2)
BatchDict(
  (key1): tensor([[2, 3, 4],
            [7, 8, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)
>>> batch.slice_along_seq(stop=3)
BatchDict(
  (key1): tensor([[0, 1, 2],
            [5, 6, 7]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)
>>> batch.slice_along_seq(step=2)
BatchDict(
  (key1): tensor([[0, 2, 4],
            [5, 7, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)

redcat.BatchDict.take_along_seq

take_along_seq(
    indices: TBaseBatch | ndarray | Tensor | Sequence,
) -> Self

Takes values along the sequence dimension.

Parameters:

Name Type Description Default
indices TBaseBatch | ndarray | Tensor | Sequence

Specifies the indices to take along the sequence dimension.

required

Returns:

Type Description
Self

The batch with the selected data.

Example usage:

>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
...     {
...         "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
...         "key2": BatchList(["a", "b"]),
...     }
... )
>>> batch.take_along_seq(torch.tensor([[3, 0, 1], [2, 3, 4]]))
BatchDict(
  (key1): tensor([[3, 0, 1],
            [7, 8, 9]], batch_dim=0, seq_dim=1)
  (key2): BatchList(data=['a', 'b'])
)