module jaxio
- class jaxio.Dataset(it: Iterable[Any])
JAX dataset.
jaxio datasets are plain iterators, compatible with Python’s native iter and next builtins, with additional methods for transforming them. The API is heavily inspired by tf.data.Dataset.
The vanilla constructor can be thought of as analogous to tf.data.Dataset.from_generator.
Note
JAX datasets are designed on the assumption that the iterator always returns pytrees of the same structure. If it does not, unexpected behavior may occur.
Warning
Datasets created with the jaxio API are NOT jit compatible by default. Call as_jit_compatible once, as early as possible in the pipeline, to explicitly control the boundary of the JAX io_callback.
- Parameters:
it – an arbitrary Python iterable; it is converted to a Python iterator automatically.
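To illustrate the iterator protocol described above, here is a minimal sketch of a wrapper around an iterable. `MiniDataset` is a hypothetical stand-in written for this example, not the jaxio class, and it only assumes that a dataset exposes its elements through a zero-argument `next_fn`.

```python
from typing import Any, Callable, Iterable


class MiniDataset:
    """Hypothetical stand-in for a dataset that wraps a Python iterable."""

    def __init__(self, it: Iterable[Any]):
        iterator = iter(it)  # any iterable is converted to an iterator
        # next_fn takes no arguments and returns the next element.
        self.next_fn: Callable[[], Any] = lambda: next(iterator)

    def __iter__(self):
        return self

    def __next__(self):
        # StopIteration from the wrapped iterator propagates when exhausted.
        return self.next_fn()


ds = MiniDataset([{"x": 1}, {"x": 2}])
first = next(ds)  # works with the next builtin
rest = list(ds)   # and with anything that calls iter
```

Every element here is a dict with the same keys, in line with the note that elements should share one pytree structure.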
- as_jit_compatible() -> Dataset
Enable JIT compatibility.
This wraps next_fn in a JAX io_callback so that it can be jitted later.
- Returns:
A new dataset that is jit compatible.
- batch(batch_size: int, axis: int = 0) -> Dataset
Yield batches of data of specified batch size.
The new dataset will use a more efficient batching (compatible with jit) if the current dataset is jit compatible.
Note
This drops the last batch if it is not full.
- Parameters:
batch_size – the size of the batches to yield.
axis – the axis to stack the batches along.
- Returns:
A new dataset that yields batches of data.
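The drop-last behavior noted above can be sketched in plain Python. `batch_next_fn` is a hypothetical helper, not part of jaxio. It assumes flat dict elements, uses Python lists in place of array stacking, and only handles the equivalent of axis 0.

```python
from typing import Any, Callable, Dict, List


def batch_next_fn(next_fn: Callable[[], Dict[str, Any]], batch_size: int):
    """Hypothetical helper: build a next_fn that returns full batches only."""

    def batched() -> Dict[str, List[Any]]:
        elements = []
        for _ in range(batch_size):
            # If the source is exhausted here, StopIteration propagates and
            # the partially collected batch is discarded.
            elements.append(next_fn())
        # "Stack" leaf by leaf: one list per key.
        return {key: [e[key] for e in elements] for key in elements[0]}

    return batched


source = iter([{"x": i, "y": 10 * i} for i in range(5)])
next_batch = batch_next_fn(lambda: next(source), batch_size=2)

batches = []
while True:
    try:
        batches.append(next_batch())
    except StopIteration:
        break
```

With five elements and a batch size of 2, the fifth element never appears in `batches`.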
- enumerate() -> Dataset
Yield (index, element) pairs.
Warning
The result will not be jit compatible.
- Returns:
A new dataset that yields (index, element) pairs.
- filter(f: Callable[[Any], bool]) -> Dataset
Get a new dataset whose next_fn filters out elements.
Warning
The result will not be jit compatible.
- Parameters:
f – a callable that takes a pytree and returns whether to keep it.
- Returns:
A new dataset that filters out elements.
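One way to picture a filtering next_fn is a loop that keeps pulling elements until the predicate accepts one. `filter_next_fn` is a hypothetical helper written for this illustration, and jaxio may implement filtering differently.

```python
from typing import Any, Callable


def filter_next_fn(next_fn: Callable[[], Any], keep: Callable[[Any], bool]):
    """Hypothetical helper: skip elements for which keep(element) is False."""

    def filtered() -> Any:
        while True:
            element = next_fn()  # StopIteration ends the filtered stream too
            if keep(element):
                return element

    return filtered


source = iter(range(6))
next_even = filter_next_fn(lambda: next(source), lambda x: x % 2 == 0)
evens = [next_even(), next_even(), next_even()]
```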
- fmap(transform: Callable[[Callable[[], Any]], Callable[[], Any]]) -> Dataset
Get a new dataset whose next_fn is a transform of the current next_fn.
This probably has no connections to functional programming, don’t hate me🥺
- Parameters:
transform – a callable that takes a next_fn and returns a new next_fn.
- Returns:
A new dataset whose next_fn is a transform of the current next_fn.
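A transform in this sense is a function from one next_fn to another. The sketch below defines a hypothetical `doubling` transform and applies it to a plain closure rather than to a real jaxio dataset.

```python
from typing import Callable


def doubling(next_fn: Callable[[], int]) -> Callable[[], int]:
    """Hypothetical transform: wrap next_fn so each element is doubled."""

    def new_next_fn() -> int:
        return 2 * next_fn()

    return new_next_fn


source = iter([1, 2, 3])
doubled_next = doubling(lambda: next(source))
values = [doubled_next(), doubled_next(), doubled_next()]
```

Going by the signature above, a function shaped like `doubling` is what would be passed as the transform argument of fmap.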
- classmethod from_next_fn(next_fn: Callable[[], Any]) -> Dataset
Create a dataset that infinitely yields fresh calls to next_fn.
- Parameters:
next_fn – callable that takes no arguments and returns a pytree.
- Returns:
A new dataset yielding the returned values of fresh calls to next_fn.
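The "infinitely yields fresh calls" behavior can be sketched with a generator. `iterate_next_fn` is a hypothetical name used for illustration only.

```python
import itertools
from typing import Any, Callable, Iterator


def iterate_next_fn(next_fn: Callable[[], Any]) -> Iterator[Any]:
    """Hypothetical helper: yield the result of a fresh call, forever."""
    while True:
        yield next_fn()


counter = itertools.count()
stream = iterate_next_fn(lambda: next(counter))
# The stream never ends on its own, so take a finite prefix.
first_three = list(itertools.islice(stream, 3))
```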
- classmethod from_pytree_slices(pytree: Any, axis: int = 0)
Create a dataset that yields the slices of a pytree along a given axis.
This is mostly useful for debugging, since all of the data lives in memory.
- Parameters:
pytree – the pytree whose leaves are to be sliced.
axis – the axis to slice along.
- Returns:
A new dataset yielding the slices of the pytree along the given axis.
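As a simplified picture of slicing a pytree, the sketch below treats a flat dict of equally long lists as the pytree and slices along the first axis only. `slices_along_first_axis` is a hypothetical helper. Real pytrees can be nested and hold arrays, which this sketch does not handle.

```python
from typing import Any, Dict, Iterator, List


def slices_along_first_axis(tree: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Hypothetical helper: yield one dict per index along the first axis."""
    length = len(next(iter(tree.values())))  # all leaves share this length
    for i in range(length):
        # Take the i-th entry of every leaf, keeping the dict structure.
        yield {key: leaf[i] for key, leaf in tree.items()}


data = {"image": [[0, 1], [2, 3], [4, 5]], "label": [0, 1, 0]}
slices = list(slices_along_first_axis(data))
```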
- jit(**jit_kwargs) -> Dataset
Get a new dataset jitting the next_fn.
Warning
This does NOT pin the computation to the CPU by default. Use the jax.default_device context manager to do so.
- Parameters:
jit_kwargs – kwargs to pass to jax.jit.
- Returns:
A new dataset whose next_fn is jitted.
- map(f: Callable[[Any], Any]) -> Dataset
Get a new dataset applying an element-wise transformation.
- Parameters:
f – a callable that takes a pytree and returns a new pytree.
- Returns:
A new dataset applying f to each element of the current dataset.
- prefetch(bufsize: int = 1) -> Dataset
Prefetch elements from the dataset into a queue of given size.
This is achieved by letting a thread pool executor (with a single worker) call the current next_fn and put the results in a queue.
Warning
The result will not be jit compatible.
- Parameters:
bufsize – the size of the queue to prefetch into.
- Returns:
A new dataset that prefetches elements into a queue.
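The mechanism described above, a single-worker thread pool filling a bounded queue, can be sketched with the standard library. The `prefetch` function and `_DONE` sentinel below are hypothetical and written for this illustration. The real implementation may schedule calls and handle shutdown differently.

```python
import queue
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable

_DONE = object()  # hypothetical sentinel marking an exhausted source


def prefetch(next_fn: Callable[[], Any], bufsize: int = 1) -> Callable[[], Any]:
    """Hypothetical helper: fill a bounded queue from a background worker."""
    buffer: queue.Queue = queue.Queue(maxsize=bufsize)
    executor = ThreadPoolExecutor(max_workers=1)

    def producer() -> None:
        try:
            while True:
                buffer.put(next_fn())  # blocks while the queue is full
        except StopIteration:
            buffer.put(_DONE)

    executor.submit(producer)
    executor.shutdown(wait=False)  # the submitted task keeps running

    def prefetched() -> Any:
        item = buffer.get()  # blocks until the worker has produced an item
        if item is _DONE:
            buffer.put(_DONE)  # keep signalling exhaustion on later calls
            raise StopIteration
        return item

    return prefetched


source = iter(range(4))
next_item = prefetch(lambda: next(source), bufsize=2)

items = []
while True:
    try:
        items.append(next_item())
    except StopIteration:
        break
```

This sketch assumes a finite source that is fully consumed. With an infinite source, the worker would stay blocked on a full queue and would need explicit shutdown handling.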
- repeat(n: int | None = None) -> Dataset
Repeat the dataset n times (or infinitely if left unspecified).
Warning
The result will not be jit compatible.
- Parameters:
n – the number of times to repeat the dataset. If None or not specified, the dataset will be repeated infinitely.
- Returns:
A new dataset repeating the current dataset.
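The finite and infinite cases can be sketched as follows. This sketch assumes the source can be re-created through a hypothetical `make_iterator` factory. How jaxio restarts the underlying iterator is not stated here, so this is an illustration of the semantics only.

```python
import itertools
from typing import Any, Callable, Iterator, Optional


def repeat(make_iterator: Callable[[], Iterator[Any]],
           n: Optional[int] = None) -> Iterator[Any]:
    """Hypothetical helper: replay a re-creatable source n times, or forever."""
    rounds = itertools.count() if n is None else range(n)
    for _ in rounds:
        yield from make_iterator()  # start a fresh pass over the data


twice = list(repeat(lambda: iter([1, 2]), n=2))
# With n left as None the stream is infinite, so take a finite prefix.
endless_prefix = list(itertools.islice(repeat(lambda: iter([1, 2])), 5))
```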
- sleep(seconds: int | float) -> Dataset
Get a new dataset that sleeps for seconds before yielding an element.
Especially useful for debugging prefetch performance.
Warning
The result will not be jit compatible.
- Parameters:
seconds – the number of seconds to sleep before yielding an element.
- Returns:
A new dataset that sleeps for seconds before yielding an element.
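A delayed next_fn of this kind can be sketched with time.sleep. `sleepy_next_fn` is a hypothetical helper written for this illustration, handy for simulating a slow data source.

```python
import time
from typing import Any, Callable, Union


def sleepy_next_fn(next_fn: Callable[[], Any],
                   seconds: Union[int, float]) -> Callable[[], Any]:
    """Hypothetical helper: wait for `seconds` before returning each element."""

    def delayed() -> Any:
        time.sleep(seconds)  # simulate a slow source
        return next_fn()

    return delayed


source = iter(["a", "b"])
slow_next = sleepy_next_fn(lambda: next(source), seconds=0.01)

start = time.perf_counter()
values = [slow_next(), slow_next()]
elapsed = time.perf_counter() - start  # roughly two sleeps' worth of time
```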