diff --git a/test/nodes/test_batch.py b/test/nodes/test_batch.py index ca518324e..e96713992 100644 --- a/test/nodes/test_batch.py +++ b/test/nodes/test_batch.py @@ -9,6 +9,7 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import TestCase +from torchdata.nodes import IterableWrapper from torchdata.nodes.batch import Batcher, Unbatcher from .utils import MockSource, run_test_save_load_state @@ -28,6 +29,11 @@ def test_batcher(self) -> None: self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j])) self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}") + def test_batcher_batch_size_zero_raises(self): + source = IterableWrapper(range(10)) + with self.assertRaises(ValueError): + Batcher(source, batch_size=0) + def test_batcher_drop_last_false(self) -> None: batch_size = 6 src = MockSource(num_samples=20) diff --git a/torchdata/nodes/batch.py b/torchdata/nodes/batch.py index 7e7ca47de..858021511 100644 --- a/torchdata/nodes/batch.py +++ b/torchdata/nodes/batch.py @@ -25,6 +25,8 @@ class Batcher(BaseNode[List[T]]): def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True): super().__init__() + if batch_size <= 0: + raise ValueError("batch_size must be a positive integer") self.source = source self.batch_size = batch_size self.drop_last = drop_last