BatchDict¶
redcat.BatchDict ¶
Bases: BaseBatch[dict[Hashable, TBaseBatch]]
Implement a batch object to represent a dictionary of batches.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
data | dict[Hashable, TBaseBatch] | Specifies the dictionary of batches. | required |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch
BatchDict(
(key1): tensor([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
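Because a BatchDict wraps other batch objects, it exposes the shared batch interface. As a minimal sketch (assuming the batch_size property inherited from BaseBatch), the batch size is the one shared by all the values:

>>> batch.batch_size
2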
redcat.BatchDict.cat_along_seq ¶
cat_along_seq(
batches: TBaseBatch | Sequence[TBaseBatch],
) -> Self
Concatenates the data of the batches to the current batch along the sequence dimension and creates a new batch.
Note that only the sequences are concatenated.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
batches | TBaseBatch \| Sequence[TBaseBatch] | Specifies the batches to concatenate along the sequence dimension. | required |
Returns:

Type | Description |
---|---|
Self | A batch with the concatenated data along the sequence dimension. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.cat_along_seq(
... BatchDict({"key1": BatchedTensorSeq(torch.tensor([[10, 11, 12], [20, 21, 22]]))})
... )
BatchDict(
(key1): tensor([[ 0, 1, 2, 3, 4, 10, 11, 12],
[ 5, 6, 7, 8, 9, 20, 21, 22]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
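Per the signature, a sequence of batches is also accepted. The following sketch (reusing the batch above and assuming the batches are concatenated in order) is equivalent to the single-batch call:

>>> batch.cat_along_seq(
...     [
...         BatchDict({"key1": BatchedTensorSeq(torch.tensor([[10, 11], [20, 21]]))}),
...         BatchDict({"key1": BatchedTensorSeq(torch.tensor([[12], [22]]))}),
...     ]
... )
BatchDict(
(key1): tensor([[ 0, 1, 2, 3, 4, 10, 11, 12],
[ 5, 6, 7, 8, 9, 20, 21, 22]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)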
redcat.BatchDict.cat_along_seq_ ¶
cat_along_seq_(
batches: TBaseBatch | Sequence[TBaseBatch],
) -> None
Concatenates the data of the batches to the current batch along the sequence dimension, in-place.
Note that only the sequences are concatenated.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
batches | TBaseBatch \| Sequence[TBaseBatch] | Specifies the batches to concatenate along the sequence dimension. | required |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.cat_along_seq_(
... BatchDict({"key1": BatchedTensorSeq(torch.tensor([[10, 11, 12], [20, 21, 22]]))})
... )
>>> batch
BatchDict(
(key1): tensor([[ 0, 1, 2, 3, 4, 10, 11, 12],
[ 5, 6, 7, 8, 9, 20, 21, 22]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
redcat.BatchDict.index_select_along_seq ¶
index_select_along_seq(
index: Tensor | Sequence[int],
) -> Self
Slices the batch along the sequence dimension at the given indices.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
index | Tensor \| Sequence[int] | Specifies the indices to select. | required |
Returns:

Type | Description |
---|---|
Self | A new batch sliced along the sequence dimension at the given indices. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.index_select_along_seq([2, 4])
BatchDict(
(key1): tensor([[2, 4], [7, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
>>> batch.index_select_along_seq(torch.tensor([2, 4]))
BatchDict(
(key1): tensor([[2, 4], [7, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
>>> batch.index_select_along_seq(torch.tensor([[2, 4], [4, 3]]))
BatchDict(
(key1): tensor([[2, 4], [9, 8]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
redcat.BatchDict.permute_along_seq ¶
permute_along_seq(
permutation: Sequence[int] | Tensor,
) -> Self
Permutes the data along the sequence dimension.
The same permutation is applied to all the sequences. This method should be called only if all the sequences have the same length.
This method only permutes the values that implement permute_along_seq.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
permutation | Sequence[int] \| Tensor | Specifies the permutation to use on the data. The dimension of the permutation input should be compatible with the shape of the data. | required |
Returns:

Type | Description |
---|---|
Self | A new batch with permuted data. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.permute_along_seq([2, 1, 3, 0, 4])
BatchDict(
(key1): tensor([[2, 1, 3, 0, 4],
[7, 6, 8, 5, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
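Per the signature, the permutation can also be given as a tensor. Assuming tensor inputs behave like sequences, this sketch is equivalent:

>>> batch.permute_along_seq(torch.tensor([2, 1, 3, 0, 4]))
BatchDict(
(key1): tensor([[2, 1, 3, 0, 4],
[7, 6, 8, 5, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)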
redcat.BatchDict.permute_along_seq_ ¶
permute_along_seq_(
permutation: Sequence[int] | Tensor,
) -> None
Permutes the data along the sequence dimension, in-place.
The same permutation is applied to all the sequences. This method should be called only if all the sequences have the same length.
This method only permutes the values that implement permute_along_seq.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
permutation | Sequence[int] \| Tensor | Specifies the permutation to use on the data. The dimension of the permutation input should be compatible with the shape of the data. | required |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.permute_along_seq_([2, 1, 3, 0, 4])
>>> batch
BatchDict(
(key1): tensor([[2, 1, 3, 0, 4],
[7, 6, 8, 5, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
redcat.BatchDict.repeat_along_seq ¶
repeat_along_seq(repeats: int) -> Self
Repeats the batch along the sequence dimension.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
repeats | int | Specifies the number of times to repeat the batch along the sequence dimension. | required |
Returns:

Type | Description |
---|---|
Self | A repeated version of the input batch. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.repeat_along_seq(2)
BatchDict(
(key1): tensor([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
[5, 6, 7, 8, 9, 5, 6, 7, 8, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
redcat.BatchDict.shuffle_along_seq ¶
shuffle_along_seq(
generator: Generator | None = None,
) -> Self
Shuffles the data along the sequence dimension.
This method should be called only if all the sequences have the same length.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
generator | Generator \| None | Specifies a pseudo-random number generator. | None |
Returns:

Type | Description |
---|---|
Self | A new batch with shuffled data. |
Raises:

Type | Description |
---|---|
RuntimeError | If the batch has multiple sequence lengths. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.shuffle_along_seq()
BatchDict(
(key1): tensor([[...]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
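For reproducible shuffling, a seeded generator can be passed. A sketch assuming a torch.Generator is accepted, per the signature; the exact permutation depends on the seed:

>>> batch.shuffle_along_seq(generator=torch.Generator().manual_seed(1))
BatchDict(
(key1): tensor([[...]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)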
redcat.BatchDict.shuffle_along_seq_ ¶
shuffle_along_seq_(
generator: Generator | None = None,
) -> None
Shuffles the data along the sequence dimension, in-place.
This method should be called only if all the sequences have the same length.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
generator | Generator \| None | Specifies a pseudo-random number generator. | None |
Raises:

Type | Description |
---|---|
RuntimeError | If the batch has multiple sequence lengths. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.shuffle_along_seq_()
>>> batch
BatchDict(
(key1): tensor([[...]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
redcat.BatchDict.slice_along_seq ¶
slice_along_seq(
start: int = 0, stop: int | None = None, step: int = 1
) -> Self
Slices the batch along the sequence dimension.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
start | int | Specifies the index where the slicing starts. | 0 |
stop | int \| None | Specifies the index where the slicing stops. | None |
step | int | Specifies the increment between each index for slicing. | 1 |
Returns:

Type | Description |
---|---|
Self | A slice of the current batch. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.slice_along_seq(start=2)
BatchDict(
(key1): tensor([[2, 3, 4],
[7, 8, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
>>> batch.slice_along_seq(stop=3)
BatchDict(
(key1): tensor([[0, 1, 2],
[5, 6, 7]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
>>> batch.slice_along_seq(step=2)
BatchDict(
(key1): tensor([[0, 2, 4],
[5, 7, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
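The arguments can be combined. Assuming the standard slice semantics shown above, start=1, stop=4, step=2 keeps sequence indices 1 and 3:

>>> batch.slice_along_seq(start=1, stop=4, step=2)
BatchDict(
(key1): tensor([[1, 3],
[6, 8]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)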
redcat.BatchDict.take_along_seq ¶
take_along_seq(
indices: TBaseBatch | ndarray | Tensor | Sequence,
) -> Self
Takes values along the sequence dimension.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
indices | TBaseBatch \| ndarray \| Tensor \| Sequence | Specifies the indices to take along the sequence dimension. | required |
Returns:

Type | Description |
---|---|
Self | The batch with the selected data. |
Example usage:
>>> import torch
>>> from redcat import BatchDict, BatchList, BatchedTensorSeq
>>> batch = BatchDict(
... {
... "key1": BatchedTensorSeq(torch.arange(10).view(2, 5)),
... "key2": BatchList(["a", "b"]),
... }
... )
>>> batch.take_along_seq(torch.tensor([[3, 0, 1], [2, 3, 4]]))
BatchDict(
(key1): tensor([[3, 0, 1],
[7, 8, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)
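Per the signature, indices may also be a plain sequence. A sketch assuming nested sequences are converted like the tensor above, so the result should match:

>>> batch.take_along_seq([[3, 0, 1], [2, 3, 4]])
BatchDict(
(key1): tensor([[3, 0, 1],
[7, 8, 9]], batch_dim=0, seq_dim=1)
(key2): BatchList(data=['a', 'b'])
)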