# Source code for ai_graph.step.foreach

"""ForEach pipeline implementation using Chain of Responsibility pattern."""

from typing import Any, Dict, Iterable, Optional

from tqdm import tqdm

from ..pipeline.base import Pipeline
from .base import BasePipelineStep


class ForEachStep(BasePipelineStep):
    """Pipeline step that runs a sub-pipeline once per item or per iteration.

    The step iterates either over a collection found in the input data
    (``items_key``) or over ``range(iterations)``. For every iteration it
    builds a shallow copy of the input data, adds the keys ``_current_item``
    and ``_iteration_index`` to it, and passes that copy through the
    sub-pipeline. The per-iteration results are collected in a list that is
    stored in the data under ``results_key``.
    """

    def __init__(
        self,
        items_key: Optional[str] = None,
        iterations: Optional[int] = None,
        results_key: str = "foreach_results",
        name: Optional[str] = None,
    ) -> None:
        """Initialize a ForEach step.

        Args:
            items_key: Key in the input data containing the items to iterate
                over. If None, the ``iterations`` parameter is used instead.
            iterations: Number of iterations to run when ``items_key`` is None
                or is not present in the input data.
            results_key: Key in the output data where results will be stored.
            name: Name of this pipeline step.

        Raises:
            ValueError: If neither ``items_key`` nor ``iterations`` is
                provided.
        """
        super().__init__(name)

        if items_key is None and iterations is None:
            raise ValueError("Either items_key or iterations must be provided")

        self.items_key = items_key
        self.iterations = iterations
        self.results_key = results_key
        # NOTE(review): self.name is presumably set by BasePipelineStep.__init__
        # (possibly to a default when ``name`` is None) -- confirm in the base class.
        self.sub_pipeline = Pipeline(name=f"{self.name}_SubPipeline")

    def add_sub_step(self, step: BasePipelineStep) -> "ForEachStep":
        """Add a step to the sub-pipeline.

        Args:
            step: Step to add to the sub-pipeline.

        Returns:
            Self, so calls can be chained.
        """
        self.sub_pipeline.add_step(step)
        return self

    def _get_items(self, data: Dict[str, Any]) -> Iterable[Any]:
        """Return the iterable this step should loop over.

        Args:
            data: Input data dictionary.

        Returns:
            The value stored under ``items_key`` when that key is configured
            and present in ``data``; a non-iterable value, or a ``str``,
            ``bytes`` or ``bytearray`` value, is wrapped in a one-element
            list. Otherwise ``range(self.iterations or 0)``.
        """
        if self.items_key is not None and self.items_key in data:
            items = data[self.items_key]
            # Text and binary payloads are iterable, but iterating them yields
            # single characters / integer byte values; treat them as one item.
            is_scalar = isinstance(items, (str, bytes, bytearray))
            if is_scalar or not hasattr(items, "__iter__"):
                return [items]
            return items  # type: ignore[no-any-return]

        # No usable items in the data: fall back to a fixed iteration count.
        return range(self.iterations or 0)

    def _process_step(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process each item in the collection, or each iteration.

        Args:
            data: Input data to be processed. It is modified in place: the
                list of per-iteration results is stored under
                ``self.results_key``.

        Returns:
            The same ``data`` dictionary, with the results list added.
        """
        items = self._get_items(data)

        # Size the progress bar from the iterable that is actually consumed.
        # Unsized iterables (e.g. generators) raise TypeError on len(); in
        # that case tqdm is given no total.
        try:
            total: Optional[int] = len(items)  # type: ignore[arg-type]
        except TypeError:
            total = None

        results = []
        for i, item in tqdm(
            enumerate(items),
            total=total,
            desc=f"Processing {self.name}",
            unit="item",
        ):
            # Per-iteration context: a shallow copy of the input data plus
            # the current item and its index.
            iteration_data = data.copy()
            iteration_data["_current_item"] = item
            iteration_data["_iteration_index"] = i

            if self.sub_pipeline.steps:
                results.append(self.sub_pipeline.process(iteration_data))
            else:
                # Without sub-steps the iteration context itself is the result.
                results.append(iteration_data)

        # Store results in the output data.
        data[self.results_key] = results
        return data